Spaces:
Runtime error
Runtime error
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelWithLMHead,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
)
# Hugging Face Hub repository that holds the fine-tuned checkpoint and its tokenizer.
username = "chsubhasis"
my_repo = "Medical-QnA-gpt2"
my_checkpoint = f"{username}/{my_repo}"

# AutoModelWithLMHead is deprecated in transformers; AutoModelForCausalLM is the
# documented replacement for decoder-only language models. The repo name and the
# GPT-2 imports suggest a GPT-2 checkpoint, for which both classes resolve to
# GPT2LMHeadModel — confirm if the checkpoint architecture ever changes.
loaded_model = AutoModelForCausalLM.from_pretrained(my_checkpoint)
loaded_tokenizer = AutoTokenizer.from_pretrained(my_checkpoint)
def generate_response(model, tokenizer, prompt, max_length=200):
    """Generate one completion for ``prompt`` and blank out the answer markers.

    Args:
        model: A Hugging Face causal language model; must expose
            ``parameters()`` (used to find its device) and ``generate()``.
        tokenizer: The tokenizer matching ``model``; its ``eos_token_id`` is
            reused as the padding id.
        prompt: The user's question text.
        max_length: Upper bound on the total sequence length (prompt tokens plus
            generated tokens) passed to ``model.generate``. Converted to ``int``
            because UI widgets such as ``gr.Slider`` deliver floats.

    Returns:
        The decoded first sequence returned by ``generate``, with every
        ``<answer>`` and ``<end>`` marker replaced by a space, or ``""`` when
        ``prompt`` is empty.
    """
    # An empty prompt encodes to zero tokens, leaving generate() nothing to
    # condition on; answer with an empty string instead of erroring in the model.
    if not prompt:
        return ""

    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Put the inputs on whatever device the model's weights live on.
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)

    output = model.generate(
        input_ids,
        max_length=int(max_length),
        num_return_sequences=1,
        # Single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        # GPT-2 tokenizers ship without a pad token, so EOS is reused as pad.
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text.replace("<answer>", " ").replace("<end>", " ")
def generate_query_response(prompt, max_length=200):
    """Answer ``prompt`` using the module-level model and tokenizer.

    This is the callback wired into the Gradio interface. It forwards
    ``prompt`` and ``max_length`` unchanged to ``generate_response`` together
    with the checkpoint loaded when the module runs.
    """
    return generate_response(loaded_model, loaded_tokenizer, prompt, max_length)
# Gradio UI: a free-text question box and a length slider. Their values are
# passed positionally, in this order, to generate_query_response.
_query_box = gr.Textbox(lines=2, placeholder="Enter your medical query here...")
_length_slider = gr.Slider(minimum=50, maximum=500, value=200, label="Maximum Length")

iface = gr.Interface(
    fn=generate_query_response,
    inputs=[_query_box, _length_slider],
    outputs="text",
    title="Medical Question Answering Bot",
    description="Ask your medical questions to get relevant answers.",
)
iface.launch(share=True)