import argparse
from typing import List, Dict
import pandas as pd
from pathlib import Path
from tqdm import tqdm

from common.constants import COLUMN_DOC_NAME
from common.constants import COLUMN_EMBEDDING
from common.constants import COLUMN_EMBEDDING_FULL
from common.constants import COLUMN_LABELS_STR
from common.constants import COLUMN_NAMES
from common.constants import COLUMN_TABLE_NAME
from common.constants import COLUMN_TEXT
from common.constants import DEVICE
from common.constants import DO_NORMALIZATION
from common.constants import COLUMN_TYPE_DOC_MAP
from components.embedding_extraction import EmbeddingExtractor


def get_label(unique_names: List) -> Dict[str, int]:
    """
    Generate labels based on the number of unique file names.

    Args:
        unique_names: List of unique file names.

    Returns:
        A dictionary mapping each file name (key) to its label (value).
    """
    dict_ = {}
    for ind, name in enumerate(unique_names):
        dict_[name] = ind
    return dict_


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file',
                        type=Path,
                        default='../../data/csv/карта_проводок_clear.csv',
                        help='path to csv file.')

    args = parser.parse_args()

    df = pd.read_csv(args.input_file)
    df = df.fillna('')
    unique_table_name = df['Название таблицы'].unique()
    class_name = get_label(unique_table_name)
    global_model_path = 'intfloat/multilingual-e5-base'
    model = EmbeddingExtractor(global_model_path, DEVICE)

    new_df = pd.DataFrame(columns=[COLUMN_DOC_NAME,
                                   COLUMN_TABLE_NAME,
                                   COLUMN_TEXT,
                                   COLUMN_NAMES,
                                   COLUMN_LABELS_STR,
                                   COLUMN_TYPE_DOC_MAP,
                                   COLUMN_EMBEDDING,
                                   COLUMN_EMBEDDING_FULL
                                   ])

    for ind, row in tqdm(df.iterrows(), total=len(df)):

        cleaned_text = row['Хозяйственные операции'].split('\n')
        doc_name = row['Название файла']
        table_name = row['Название таблицы']
        try:
            column_names = [i.replace('\t', '').strip() for i in row['Columns'].split('\n')]
        except AttributeError:
            column_names = []

        type_docs = row['TypeDocs']
        if not type_docs:
            type_docs = '1C'

        for text in cleaned_text:
            if text != '':
                query_tokens = model.query_tokenization('passage: ' + text)
                query_embeds = model.query_embed_extraction(query_tokens.to(DEVICE), DO_NORMALIZATION)[0]

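                # The "full" embedding must be built from the tokens that include
                # the document and table names, not from the text-only tokens.
                query_tokens_full = model.query_tokenization(f'passage: {doc_name} {table_name} {text}')
                query_embeds_full = model.query_embed_extraction(query_tokens_full.to(DEVICE), DO_NORMALIZATION)[0]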

                new_df.loc[len(new_df.index)] = [doc_name,
                                                 table_name,
                                                 text,
                                                 column_names,
                                                 class_name[table_name],
                                                 type_docs,
                                                 query_embeds,
                                                 query_embeds_full,
                                                 ]

    # Save next to the input file, swapping its extension for .pkl.
    new_df.to_pickle(args.input_file.with_suffix('.pkl'))