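"""
Builds an embedding dataset from a CSV of document tables.

For every non-empty line of the 'Хозяйственные операции' (business
operations) column, two embeddings are extracted with
intfloat/multilingual-e5-base: one for the operation text alone and one
with the document and table names prepended. The resulting DataFrame is
pickled next to the input file.
"""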
import argparse
from pathlib import Path
from typing import Dict, List

import pandas as pd
from tqdm import tqdm

from common.constants import (
    COLUMN_DOC_NAME,
    COLUMN_EMBEDDING,
    COLUMN_EMBEDDING_FULL,
    COLUMN_LABELS_STR,
    COLUMN_NAMES,
    COLUMN_TABLE_NAME,
    COLUMN_TEXT,
    COLUMN_TYPE_DOC_MAP,
    DEVICE,
    DO_NORMALIZATION,
)
from components.embedding_extraction import EmbeddingExtractor


def get_label(unique_names: List[str]) -> Dict[str, int]:
    """
    Generates a label for each unique file name.

    Args:
        unique_names: List of unique file names.

    Returns:
        Dictionary mapping each file name to an integer label.

    Example:
        >>> get_label(['a.csv', 'b.csv'])
        {'a.csv': 0, 'b.csv': 1}
    """
    return {name: ind for ind, name in enumerate(unique_names)}


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file',
                        type=Path,
                        default='../../data/csv/карта_проводок_clear.csv',
                        help='Path to the input CSV file.')
    args = parser.parse_args()

    df = pd.read_csv(args.input_file)
    df = df.fillna('')

    # Map each unique table name to an integer class label.
    unique_table_name = df['Название таблицы'].unique()
    class_name = get_label(unique_table_name)

    global_model_path = 'intfloat/multilingual-e5-base'
    model = EmbeddingExtractor(global_model_path, DEVICE)

    new_df = pd.DataFrame(columns=[COLUMN_DOC_NAME,
                                   COLUMN_TABLE_NAME,
                                   COLUMN_TEXT,
                                   COLUMN_NAMES,
                                   COLUMN_LABELS_STR,
                                   COLUMN_TYPE_DOC_MAP,
                                   COLUMN_EMBEDDING,
                                   COLUMN_EMBEDDING_FULL])

    for _, row in tqdm(df.iterrows(), total=len(df)):
        cleaned_text = row['Хозяйственные операции'].split('\n')
        doc_name = row['Название файла']
        table_name = row['Название таблицы']
        try:
            column_names = [i.replace('\t', '').strip() for i in row['Columns'].split('\n')]
        except AttributeError:
            # The 'Columns' cell is not a string, so there are no column names.
            column_names = []
        type_docs = row['TypeDocs']
        if not type_docs:
            type_docs = '1C'

        for text in cleaned_text:
            if text != '':
                # E5 models expect the 'passage: ' prefix on texts being indexed.
                query_tokens = model.query_tokenization('passage: ' + text)
                query_embeds = model.query_embed_extraction(query_tokens.to(DEVICE), DO_NORMALIZATION)[0]
                # Full-context embedding: document and table names prepended to the text.
                query_tokens_full = model.query_tokenization(f'passage: {doc_name} {table_name} {text}')
                query_embeds_full = model.query_embed_extraction(query_tokens_full.to(DEVICE), DO_NORMALIZATION)[0]
                new_df.loc[len(new_df.index)] = [doc_name,
                                                 table_name,
                                                 text,
                                                 column_names,
                                                 class_name[table_name],
                                                 type_docs,
                                                 query_embeds,
                                                 query_embeds_full]
    # Save the embedding DataFrame next to the input CSV, swapping the extension.
    new_df.to_pickle(args.input_file.with_suffix('.pkl'))
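
# Usage sketch (the script file name below is hypothetical):
#   python build_embeddings.py --input_file ../../data/csv/карта_проводок_clear.csv
# The pickled DataFrame can be loaded back with pd.read_pickle on the
# generated .pkl file.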