import argparse
from pathlib import Path

import pandas as pd
from sklearn.metrics import f1_score

from common.constants import (
    COLUMN_LABELS_STR,
    COLUMN_TEXT,
    DEVICE,
    DO_NORMALIZATION,
)
from components.embedding_extraction import EmbeddingExtractor
from components.faiss_vector_database import FaissVectorDatabase

K_NEIGHBORS = 1


def main(query):
    """Embed a query string and return its nearest neighbor from the database.

    Relies on the module-level ``model`` and ``database`` objects
    initialised in the ``__main__`` block below.
    """
    # Any text preprocessing can be added here.
    cleaned_text = query.replace('\n', ' ')
    # E5 models expect the 'query: ' prefix on search queries.
    cleaned_text = 'query: ' + cleaned_text
    query_tokens = model.query_tokenization(cleaned_text)
    query_embeds = model.query_embed_extraction(query_tokens.to(DEVICE), DO_NORMALIZATION)[0]
    query_embeds = query_embeds[None, :]  # add a batch dimension for the FAISS search
    answer = database.search_transaction_map(query_embeds, K_NEIGHBORS)
    return answer[0]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=Path,
                        default='../../data/csv/Карта проводок clear_3_3.pkl',
                        help='Path to the pickled dataframe.')
    args = parser.parse_args()

    global_model_path = 'intfloat/multilingual-e5-base'

    df = pd.read_pickle(args.input_file)
    database = FaissVectorDatabase(args.input_file)
    model = EmbeddingExtractor(global_model_path, DEVICE)

    true, pred = [], []
    for ind, row in df.iterrows():
        answer = main(row[COLUMN_TEXT])
        true.append(row[COLUMN_LABELS_STR])
        pred.append(answer[COLUMN_LABELS_STR])
        # Log every misclassified row for inspection.
        if row[COLUMN_LABELS_STR] != answer[COLUMN_LABELS_STR]:
            print(f'True labels: {row[COLUMN_LABELS_STR]}-----Pred labels: {answer[COLUMN_LABELS_STR]}')
            print('Table name:', row['TableName'])
            print(row['DocName'], '---', answer['doc_name'])
            print('-----------------------------------------')

    print('macro', f1_score(true, pred, average='macro'))
    print('micro', f1_score(true, pred, average='micro'))
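
# Example invocation (the script filename below is illustrative, not from the repo).
# The run reads the pickled dataframe, retrieves the single nearest neighbor
# (K_NEIGHBORS = 1) for each row's text, logs every mismatch, and prints
# macro/micro F1 over the predicted labels:
#
#   python evaluate_search.py --input_file '../../data/csv/Карта проводок clear_3_3.pkl'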