# Source: INVOHIDE_inisw8 / ner.py
# Author: esun-choi — commit 9b879f1 ("Initial Commit"), 2.74 kB
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from collections import defaultdict
import torch
# Hugging Face model id of the token-classification (NER) model used below.
MODEL_NAME = "Leo97/KoELECTRA-small-v3-modu-ner"

# Use the GPU when one is available, otherwise fall back to the CPU.
# Hard-coding torch.device("cuda") makes model.to(device) raise on
# machines without a usable CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
model.to(device)
def check_entity(entities, target_types=("LC", "PS")):
    """Return 1 if any entity label mentions one of ``target_types``, else 0.

    Args:
        entities: Iterable of dicts carrying an ``'entity'`` label key.
            ``ner()`` passes the raw result of the transformers NER pipeline.
        target_types: Label fragments to look for. Each fragment is matched
            case-insensitively as a substring of each label. Defaults to
            ``("LC", "PS")``, which preserves the original behaviour.

    Returns:
        ``1`` as soon as a matching label is found, otherwise ``0``
        (including for empty input or dicts without a label).
    """
    wanted = [tag.upper() for tag in target_types]
    for entity_info in entities:
        # ``or ''`` also covers an explicit None (or other falsy) label,
        # which would otherwise make ``.upper()`` raise AttributeError.
        entity_value = (entity_info.get('entity') or '').upper()
        # Substring match, so prefixed labels such as 'B-PS' or 'I-LC' count.
        if any(tag in entity_value for tag in wanted):
            return 1
    return 0
# Lazily built, module-wide NER pipeline; see _get_ner_pipeline().
_NER_PIPELINE = None


def _get_ner_pipeline():
    """Build the transformers NER pipeline on first use and reuse it afterwards.

    The pipeline wraps the module-level ``model``, ``tokenizer`` and
    ``device``. It is created once so that each ``ner()`` call does not pay
    the pipeline construction cost. NOTE: if those module-level objects are
    reassigned later, the cached pipeline will not pick up the change.
    """
    global _NER_PIPELINE
    if _NER_PIPELINE is None:
        _NER_PIPELINE = pipeline("ner", model=model, tokenizer=tokenizer, device=device)
    return _NER_PIPELINE


def ner(example):
    """Return 1 if the NER model tags a PS or LC entity in ``example``, else 0.

    Args:
        example: Input passed straight to the transformers "ner" pipeline
            (presumably a text string; the pipeline also accepts a list of
            strings — confirm against callers).

    Returns:
        The 0/1 flag computed by ``check_entity`` on the pipeline output.
    """
    entities = _get_ner_pipeline()(example)
    return check_entity(entities)
# 하나 ("one")
# def find_longest_value_key(input_dict):
# max_length = 0
# max_length_keys = []
# for key, value in input_dict.items():
# current_length = len(value)
# if current_length > max_length:
# max_length = current_length
# max_length_keys = [key]
# elif current_length == max_length:
# max_length_keys.append(key)
# if len(max_length_keys) == 1:
# return 0
# else:
# return 1
# def find_longest_value_key2(input_dict):
# if not input_dict:
# return None
# max_key = max(input_dict, key=lambda k: len(input_dict[k]))
# return max_key
# def find_most_frequent_entity(entities):
# entity_counts = defaultdict(list)
# for item in entities:
# split_entity = item['entity'].split('-')
# entity_type = split_entity[1]
# entity_counts[entity_type].append(item['score'])
# number=find_longest_value_key(entity_counts)
# if number==1:
# max_entities = []
# max_score_average = -1
# for entity, scores in entity_counts.items():
# score_average = sum(scores) / len(scores)
# if score_average > max_score_average:
# max_entities = [entity]
# max_score_average = score_average
# elif score_average == max_score_average:
# max_entities.append(entity)
# if len(max_entities)>0:
# return max_entities if len(max_entities) > 1 else max_entities[0]
# else:
# return "Do not mosaik"
# else:
# A=find_longest_value_key2(entity_counts)
# return A
# ํ•˜๋‚˜๋ผ๋„ ps ๋‚˜ lc ๊ฐ€ ์žˆ์œผ๋ฉด ๋ฐ”๋กœ ps , lc ๊บผ๋‚ด๊ธฐ
# label=filtering(ner_results)
# if label.find("PS")>-1 or label.find("LC")>-1:
# return 1
# else:
# return 0
#print(ner("홍길동"))  # sample Korean personal name (Hong Gil-dong)
#label=check_label(example)