Medical_QnA / app.py
chsubhasis's picture
app and requirements files added
706a667
raw
history blame
2 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoModelWithLMHead, AutoTokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, DataCollatorForLanguageModeling
#def greet(name):
# return "Hello " + name + "!!"
# Hugging Face Hub coordinates of the fine-tuned checkpoint served by this app.
username = "chsubhasis"
my_repo = "Medical-QnA-gpt2"
my_checkpoint = f"{username}/{my_repo}"

# Load the model and tokenizer once at import time so every request reuses them.
# AutoModelWithLMHead is deprecated in transformers; the app only calls
# .generate() on the model, so it is loaded through AutoModelForCausalLM.
# NOTE(review): assumes the checkpoint is a causal LM (the repo name suggests
# GPT-2) — confirm against the checkpoint's config.
loaded_model = AutoModelForCausalLM.from_pretrained(my_checkpoint)
loaded_tokenizer = AutoTokenizer.from_pretrained(my_checkpoint)
def generate_response(model, tokenizer, prompt, max_length=200):
    """Generate a completion for *prompt* and return it as plain text.

    Args:
        model: Object exposing ``parameters()`` and ``generate()`` (for
            example a transformers language model).
        tokenizer: Tokenizer matching ``model``; must provide ``encode``,
            ``decode`` and ``eos_token_id``.
        prompt: The query text to complete.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            It is coerced to ``int`` because UI widgets such as sliders may
            deliver a float.

    Returns:
        The decoded output with the literal ``<answer>`` and ``<end>``
        markers each replaced by a single space, or an empty string when
        ``prompt`` is empty or whitespace-only.
    """
    # A blank prompt encodes to an empty tensor, which generate() cannot
    # extend; answer with an empty string instead of surfacing an error.
    if not prompt or not prompt.strip():
        return ""

    max_length = int(max_length)

    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Run on whatever device the model's weights already live on.
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)

    # A single, unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)
    # Reuse EOS as the padding id (presumably the tokenizer defines no pad
    # token of its own — verify against the checkpoint).
    pad_token_id = tokenizer.eos_token_id

    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id,
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # These markers are plain text in the decoded output, so
    # skip_special_tokens did not remove them; blank them out explicitly.
    return generated_text.replace("<answer>", " ").replace("<end>", " ")
def generate_query_response(prompt, max_length=200):
    """Answer *prompt* using the module-level model and tokenizer.

    This is the callback handed to the Gradio interface; it forwards the
    prompt and ``max_length`` unchanged to ``generate_response`` together
    with ``loaded_model`` and ``loaded_tokenizer``.
    """
    return generate_response(loaded_model, loaded_tokenizer, prompt, max_length)
#demo = gr.Interface(fn=greet, inputs="text", outputs="text")
#demo.launch()
# Input widgets: a two-line free-text box for the question and a slider whose
# value is passed to generate_query_response as max_length.
_query_box = gr.Textbox(lines=2, placeholder="Enter your medical query here...")
_length_slider = gr.Slider(minimum=50, maximum=500, value=200, label="Maximum Length")

# Wire the widgets to the generation callback; the result is shown as text.
iface = gr.Interface(
    fn=generate_query_response,
    inputs=[_query_box, _length_slider],
    outputs="text",
    title="Medical Question Answering Bot",
    description="Ask your medical questions to get relevant answers.",
)

# Start the web app; share=True asks Gradio for a public share link.
iface.launch(share=True)