|
import uvicorn |
|
import threading |
|
from typing import Optional |
|
from transformers import pipeline |
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
import pandas as pd |
|
|
|
from pprint import pprint |
|
|
|
import gradio as gr |
|
from transformers import pipeline |
|
from fastapi import FastAPI |
|
from pydantic import BaseModel |
|
from typing import List, Dict |
|
|
|
|
|
# ASGI application object; the HTTP routes defined in this module register on it.
app = FastAPI()

# Slot for the NER pipeline. It stays None until get_cached_model() fills it.
model_cache: Optional[object] = None
|
|
|
def load_model(model_name: str = "LampOfSocrates/bert-cased-plodcw-sourav"):
    """Build a token-classification ("ner") pipeline for a checkpoint.

    The tokenizer and the model are both loaded from ``model_name`` so the
    two can never drift apart (the identifier used to be spelled out twice).

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
            Defaults to the checkpoint this app was written for.

    Returns:
        The object returned by ``transformers.pipeline("ner", ...)``; it is
        called with a text string by the rest of this module.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    token_classifier = AutoModelForTokenClassification.from_pretrained(model_name)

    # Report the label set once at load time as a start-up sanity check.
    id2label = token_classifier.config.id2label
    print(f"Can recognise the following labels {id2label}")

    return pipeline("ner", model=token_classifier, tokenizer=tokenizer)
|
|
|
def load_plod_cw_dataset():
    """Fetch the ``surrey-nlp/PLOD-CW`` dataset and return it as loaded."""
    # Imported here rather than at file level, so `datasets` is only needed
    # when this helper is actually called.
    from datasets import load_dataset

    return load_dataset("surrey-nlp/PLOD-CW")
|
|
|
def get_cached_model():
    """Return the shared NER pipeline, building it on first use.

    The pipeline is stored in the module-level ``model_cache`` so that
    ``load_model()`` runs at most once per process.
    """
    global model_cache

    # Fast path: something is already cached.
    if model_cache is not None:
        return model_cache

    model_cache = load_model()
    return model_cache
|
|
|
|
|
# Warm the cache at import time so the first request does not pay the load cost.
model = get_cached_model()
|
|
|
|
|
|
|
class Entity(BaseModel):
    """One tagged span, built from a single dict of the NER pipeline output."""

    # Label assigned to the span (the value looked up for colouring in the demo).
    entity: str
    # Score reported by the pipeline for this span — presumably a confidence; confirm.
    score: float
    # Offsets into the submitted text; the demo slices the text as text[start:end].
    start: int
    end: int
    # Token text as reported by the pipeline.
    word: str
|
|
|
class NERResponse(BaseModel):
    """Response body of ``POST /ner``: the list of entities found."""

    # One item per dict returned by the pipeline, in the order it returned them.
    entities: List[Entity]
|
|
|
class NERRequest(BaseModel):
    """Request body of ``POST /ner``."""

    # Raw text to run through the NER pipeline.
    text: str
|
|
|
@app.get("/hello")
def read_root():
    """Return a fixed greeting, as a trivial check that the API is up."""
    greeting = {"message": "Hello, World!"}
    return greeting
|
|
|
|
|
@app.post("/ner", response_model=NERResponse)
def get_entities(request: NERRequest):
    """Run the NER pipeline over ``request.text`` and return its entities.

    Args:
        request: Request body carrying the text to analyse.

    Returns:
        An ``NERResponse`` whose ``entities`` list holds one ``Entity`` per
        dict produced by the pipeline. The list is empty when the pipeline
        produces nothing.
    """
    print(request)
    model = get_cached_model()

    entities = model(request.text)
    # The debug output used to index [0] unconditionally. When the pipeline
    # yields no entities, that raised IndexError and failed the request
    # instead of returning an empty result.
    if entities:
        print(entities[0].keys())

    response_entities = [Entity(**entity) for entity in entities]
    if response_entities:
        print(response_entities[0])
    return NERResponse(entities=response_entities)
|
|
|
def get_color_for_label(label: str) -> str:
    """Map an entity label to the CSS colour name used to highlight it.

    Labels with no dedicated colour fall back to ``"black"``.
    """
    palette = {"I-LF": "red", "B-AC": "blue", "LOC": "green"}
    try:
        return palette[label]
    except KeyError:
        # Unknown label: render it in the default colour.
        return "black"
|
|
|
|
|
|
|
def ner_demo(text):
    """Return ``text`` as HTML with every recognised entity colour-highlighted.

    The pipeline reports ``start``/``end`` offsets relative to the original
    text. The output is therefore assembled in one left-to-right pass over
    the original string. Splicing markup into a growing copy would shift
    every later offset and highlight the wrong characters.

    Args:
        text: Raw user input from the Gradio textbox.

    Returns:
        An HTML string. Text outside entities is HTML-escaped; each entity is
        wrapped in a bold ``<span>`` coloured by ``get_color_for_label``.
        An empty input yields an empty string.
    """
    # Function-scope import keeps this block self-contained.
    import html

    if not text:
        return ""

    model = get_cached_model()
    entities = model(text)

    # Keep only entities that carry usable offsets, ordered by position.
    located = [
        entity
        for entity in entities
        if entity.get("start") is not None and entity.get("end") is not None
    ]
    located.sort(key=lambda entity: entity["start"])

    pieces = []
    cursor = 0  # index in `text` up to which output has already been emitted
    for entity in located:
        start, end = entity["start"], entity["end"]
        if start < cursor:
            # Overlaps a span that was already emitted; skip it rather than
            # duplicate characters in the output.
            continue
        # User-supplied text is escaped so it cannot inject markup.
        pieces.append(html.escape(text[cursor:start]))
        color = get_color_for_label(entity["entity"])
        entity_text = html.escape(text[start:end])
        pieces.append(
            f'<span style="color: {color}; font-weight: bold;">{entity_text}</span>'
        )
        cursor = end

    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)
|
|
|
# Blurb shown at the top of the Gradio page description.
PROJECT_INTRO = (
    "This is a HF Spaces hosted Gradio App built by NLP Group 27 . "
    "The model has been trained on surrey-nlp/PLOD-CW dataset"
)
|
|
|
# Gradio front end: one multi-line textbox in, highlighted HTML out.
demo = gr.Interface(
    title="Named Entity Recognition on PLOD-CW ",
    description=(
        PROJECT_INTRO
        + "\n\nEnter text to extract named entities using a NER model."
    ),
    fn=ner_demo,
    inputs=gr.Textbox(placeholder="Enter text here...", lines=10),
    outputs="html",
)
|
|
|
|
|
def run_fastapi():
    """Serve the FastAPI ``app`` with uvicorn on 0.0.0.0:8000."""
    bind_address = "0.0.0.0"
    listen_port = 8000
    uvicorn.run(app, host=bind_address, port=listen_port)
|
|
|
|
|
def run_gradio():
    """Launch the Gradio ``demo`` on 0.0.0.0:7860."""
    bind_address = "0.0.0.0"
    listen_port = 7860
    demo.launch(server_name=bind_address, server_port=listen_port)
|
|
|
|
|
# Start the REST API first, then the Gradio UI, each on its own thread.
for _server_entry in (run_fastapi, run_gradio):
    threading.Thread(target=_server_entry).start()
|
|
|
|