"""Flag whether a text contains a PS- or LC-tagged named entity.

A KoELECTRA token-classification model is loaded once at import time and
exposed through :func:`ner`, which returns ``1`` when any predicted entity
label contains ``PS`` or ``LC`` and ``0`` otherwise.
"""

from collections import defaultdict
from functools import lru_cache

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers import pipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cuda")

tokenizer = AutoTokenizer.from_pretrained("Leo97/KoELECTRA-small-v3-modu-ner")
model = AutoModelForTokenClassification.from_pretrained("Leo97/KoELECTRA-small-v3-modu-ner")
model.to(device)


def check_entity(entities) -> int:
    """Return 1 if any entity label contains ``LC`` or ``PS``, else 0.

    Args:
        entities: Iterable of dict-like records as produced by the NER
            pipeline. Only the ``'entity'`` key is read; a record without
            it is treated as having an empty label.

    Returns:
        ``1`` when at least one upper-cased label contains ``'LC'`` or
        ``'PS'`` (presumably the location / person tags of the model's
        label set -- confirm against the model card), otherwise ``0``.
    """
    # Substring match on purpose: labels may carry a prefix such as "B-"/"I-".
    labels = (record.get('entity', '').upper() for record in entities)
    return int(any('LC' in label or 'PS' in label for label in labels))


@lru_cache(maxsize=1)
def _get_ner_pipeline():
    """Build the token-classification pipeline once and reuse it afterwards."""
    return pipeline("ner", model=model, tokenizer=tokenizer, device=device)


def ner(example) -> int:
    """Run NER on ``example`` and report whether a PS/LC entity was found.

    Args:
        example: Input forwarded unchanged to the ``"ner"`` pipeline.

    Returns:
        ``1`` if :func:`check_entity` finds a PS or LC label among the
        predicted entities, otherwise ``0``.
    """
    # The pipeline is cached: constructing it on every call is pure overhead
    # because the model, tokenizer and device never change after import.
    predictions = _get_ner_pipeline()(example)
    return check_entity(predictions)


# ---------------------------------------------------------------------------
# Disabled legacy heuristics, kept for reference only (not executed).
# ---------------------------------------------------------------------------
# One
# def find_longest_value_key(input_dict):
#     max_length = 0
#     max_length_keys = []
#     for key, value in input_dict.items():
#         current_length = len(value)
#         if current_length > max_length:
#             max_length = current_length
#             max_length_keys = [key]
#         elif current_length == max_length:
#             max_length_keys.append(key)
#     if len(max_length_keys) == 1:
#         return 0
#     else:
#         return 1

# def find_longest_value_key2(input_dict):
#     if not input_dict:
#         return None
#     max_key = max(input_dict, key=lambda k: len(input_dict[k]))
#     return max_key

# def find_most_frequent_entity(entities):
#     entity_counts = defaultdict(list)
#     for item in entities:
#         split_entity = item['entity'].split('-')
#         entity_type = split_entity[1]
#         entity_counts[entity_type].append(item['score'])
#     number=find_longest_value_key(entity_counts)
#     if number==1:
#         max_entities = []
#         max_score_average = -1
#         for entity, scores in entity_counts.items():
#             score_average = sum(scores) / len(scores)
#             if score_average > max_score_average:
#                 max_entities = [entity]
#                 max_score_average = score_average
#             elif score_average == max_score_average:
#                 max_entities.append(entity)
#         if len(max_entities)>0:
#             return max_entities if len(max_entities) > 1 else max_entities[0]
#         else:
#             return "Do not mosaik"
#     else:
#         A=find_longest_value_key2(entity_counts)
#         return A

# If there is even a single PS or LC entity, extract PS / LC right away.
# label=filtering(ner_results)
# if label.find("PS")>-1 or label.find("LC")>-1:
#     return 1
# else:
#     return 0
#print(ner("홍길동"))
#label=check_label(example)