# NOTE(review): removed web-viewer artifacts that were captured together with
# the source (page chrome, file-size line, commit hash, line-number gutter) —
# they are not part of the program and made the file invalid Python.
import argparse
from typing import List, Dict
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from common.constants import COLUMN_DOC_NAME
from common.constants import COLUMN_EMBEDDING
from common.constants import COLUMN_EMBEDDING_FULL
from common.constants import COLUMN_LABELS_STR
from common.constants import COLUMN_NAMES
from common.constants import COLUMN_TABLE_NAME
from common.constants import COLUMN_TEXT
from common.constants import DEVICE
from common.constants import DO_NORMALIZATION
from common.constants import COLUMN_TYPE_DOC_MAP
from components.embedding_extraction import EmbeddingExtractor
def get_label(unique_names: List) -> Dict[str, int]:
    """Map each unique file name to a sequential integer label.

    Args:
        unique_names: Sequence of unique file names.

    Returns:
        Dict mapping file name -> zero-based label index (enumeration order).
    """
    return {name: label for label, name in enumerate(unique_names)}
if __name__ == '__main__':
    # Build a pickled DataFrame of per-operation text embeddings from a CSV
    # of accounting operations ("карта проводок").
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file',
                        type=Path,
                        default='../../data/csv/карта_проводок_clear.csv',
                        help='path to csv file.')
    args = parser.parse_args()

    df = pd.read_csv(args.input_file)
    df = df.fillna('')

    # One integer label per distinct table name.
    unique_table_name = df['Название таблицы'].unique()
    class_name = get_label(unique_table_name)

    global_model_path = 'intfloat/multilingual-e5-base'
    model = EmbeddingExtractor(global_model_path, DEVICE)

    # Accumulate rows in a plain list and build the DataFrame once at the
    # end: appending via `new_df.loc[len(new_df.index)]` is O(n^2) overall.
    records = []
    for ind, row in tqdm(df.iterrows(), total=len(df)):
        # Each cell holds several operations separated by newlines.
        cleaned_text = row['Хозяйственные операции'].split('\n')
        doc_name = row['Название файла']
        table_name = row['Название таблицы']
        try:
            column_names = [i.replace('\t', '').strip() for i in row['Columns'].split('\n')]
        except AttributeError:
            # Non-string cell (e.g. a numeric value) has no .split().
            column_names = []
        type_docs = row['TypeDocs']
        if not type_docs:
            type_docs = '1C'  # empty cell -> default document type
        for text in cleaned_text:
            if text == '':
                continue
            # Embedding of the operation text alone.
            query_tokens = model.query_tokenization('passage: ' + text)
            query_embeds = model.query_embed_extraction(query_tokens.to(DEVICE), DO_NORMALIZATION)[0]
            # Embedding of the text prefixed with document and table names.
            query_tokens_full = model.query_tokenization(f'passage: {doc_name} {table_name} {text}')
            # BUG FIX: must embed the *full* tokenization; the original
            # reused `query_tokens` here, making both embeddings identical.
            query_embeds_full = model.query_embed_extraction(query_tokens_full.to(DEVICE), DO_NORMALIZATION)[0]
            records.append([doc_name,
                            table_name,
                            text,
                            column_names,
                            class_name[table_name],
                            type_docs,
                            query_embeds,
                            query_embeds_full,
                            ])

    new_df = pd.DataFrame(records, columns=[COLUMN_DOC_NAME,
                                            COLUMN_TABLE_NAME,
                                            COLUMN_TEXT,
                                            COLUMN_NAMES,
                                            COLUMN_LABELS_STR,
                                            COLUMN_TYPE_DOC_MAP,
                                            COLUMN_EMBEDDING,
                                            COLUMN_EMBEDDING_FULL
                                            ])
    # Same directory as the input, '.csv' suffix swapped for '.pkl'.
    new_df.to_pickle(f'{args.input_file.parent}/{args.input_file.name[:-4]}.pkl')