# Spaces:
# Sleeping
# Sleeping
# Standard library
import argparse
from pathlib import Path

# Third-party
import pandas as pd
from sklearn.metrics import f1_score

# Project-local
from common.constants import COLUMN_LABELS_STR, COLUMN_TEXT, DEVICE, DO_NORMALIZATION
from components.embedding_extraction import EmbeddingExtractor
from components.faiss_vector_database import FaissVectorDatabase
def main(query: str):
    """Embed *query* and return the single nearest transaction-map record.

    Relies on the module-level ``model`` (EmbeddingExtractor) and
    ``database`` (FaissVectorDatabase) created in the ``__main__`` section,
    and on the module-level ``K_NEIGHBORS`` constant.

    :param query: raw query text; newlines are flattened before embedding.
    :return: the best-matching record returned by the FAISS search.
    """
    # Flatten newlines and prepend the "query: " prefix expected by the
    # intfloat/multilingual-e5 family of models.
    cleaned_text = 'query: ' + query.replace("\n", " ")
    query_tokens = model.query_tokenization(cleaned_text)
    # [0]: embedding of the single query; DO_NORMALIZATION toggles
    # L2-normalization inside the extractor.
    query_embeds = model.query_embed_extraction(query_tokens.to(DEVICE), DO_NORMALIZATION)[0]
    # FAISS expects a 2-D (batch, dim) array, so add a leading batch axis.
    query_embeds = query_embeds[None, :]
    answer = database.search_transaction_map(query_embeds, K_NEIGHBORS)
    return answer[0]
K_NEIGHBORS = 1  # number of nearest neighbours retrieved per query

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file',
                        type=Path,
                        default='../../data/csv/Карта проводок clear_3_3.pkl',
                        # NOTE: the input is a pickled DataFrame (pd.read_pickle
                        # below), not a CSV — the old help text was wrong.
                        help='path to the pickled DataFrame with texts and labels.')
    args = parser.parse_args()

    global_model_path = 'intfloat/multilingual-e5-base'
    df = pd.read_pickle(args.input_file)
    database = FaissVectorDatabase(args.input_file)
    model = EmbeddingExtractor(global_model_path, DEVICE)

    # Run 1-NN label prediction over every row and collect true/predicted labels.
    true, pred = [], []
    for _, row in df.iterrows():
        answer = main(row[COLUMN_TEXT])
        true.append(row[COLUMN_LABELS_STR])
        pred.append(answer[COLUMN_LABELS_STR])
        # Print every mismatch for manual error analysis.
        if row[COLUMN_LABELS_STR] != answer[COLUMN_LABELS_STR]:
            print(f'True labels: {row[COLUMN_LABELS_STR]}-----Pred labels: {answer[COLUMN_LABELS_STR]}')
            print('Название таблицы', row['TableName'])
            print(row['DocName'], '---', answer['doc_name'])
            print('-----------------------------------------')
    print('macro', f1_score(true, pred, average='macro'))
    print('micro', f1_score(true, pred, average='micro'))