# Hugging Face Spaces app (page status when captured: Sleeping)
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import torch

# Hub repository id used for BOTH the tokenizer and the model. Keeping it in one
# constant guarantees the two are always loaded from the same checkpoint.
MODEL_ID = "ahmed792002/alzheimers_memory_support_ai"

# Load the pre-trained tokenizer and model once at import time; the chat handler
# reads these module-level objects on every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# Chatbot function
def chatbot(query, history, system_message, max_length, temperature, top_k, top_p,
            *, max_attempts=5):
    """Generate a reply to ``query`` with the causal language model.

    ``gr.ChatInterface`` calls this as ``fn(message, history, *additional_inputs)``,
    so the parameters after ``history`` arrive from the UI controls in order.

    Args:
        query: The user's latest message; it is the only text fed to the model.
        history: Prior chat turns supplied by Gradio. Currently ignored.
        system_message: Value of the "System message" box. Currently ignored.
        max_length: Token budget for prompt plus generated text (cast to int).
        temperature: Sampling temperature (cast to float).
        top_k: Top-k sampling cutoff (cast to int).
        top_p: Nucleus-sampling cutoff (cast to float).
        max_attempts: Maximum number of sampled generations to try before
            giving up (keyword-only).

    Returns:
        The text between the first and second double quote of the decoded
        output, or everything after the first quote if there is only one. If no
        attempt yields such a segment other than a bare ".", a short apology
        message is returned instead.
    """
    # Use the tokenizer's __call__ (not encode()) so we also get an attention mask;
    # without one, generate() has to infer it and emits a warning.
    encoded = tokenizer(query, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded.get("attention_mask")

    # Tokenizers without a pad token report None; generate() would fall back to EOS
    # on its own (with a warning), so make that fallback explicit.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Sampling is stochastic, so a bad draw (no quoted reply, or a bare ".") is
    # retried -- but only a bounded number of times, so a degenerate model cannot
    # hang the request.
    for _ in range(max_attempts):
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=True,
            max_length=int(max_length),
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
            pad_token_id=pad_token_id,
        )
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # For a causal LM, generate() returns the prompt followed by the
        # continuation, so the reply is read from the first double-quoted segment.
        # NOTE(review): assumes the fine-tuning data wrapped answers in double
        # quotes -- confirm against the model's training format.
        segments = text.split('"')
        if len(segments) > 1 and segments[1] != ".":
            return segments[1]

    return "Sorry, I couldn't come up with a reply. Please try again."
# Extra UI controls for the chat app. ChatInterface passes their values to
# chatbot() positionally after (message, history), so this order must match
# chatbot's parameters: system_message, max_length, temperature, top_k, top_p.
_extra_controls = [
    # Passed through as system_message; chatbot() accepts it but does not use it.
    gr.Textbox(value="You are a friendly chatbot.", label="System message"),
    # Generation budget, converted with int() inside chatbot().
    gr.Slider(128, 1024, value=256, step=64, label="Max Length"),
    # Sampling temperature, converted with float().
    gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
    # Top-k cutoff, converted with int().
    gr.Slider(1, 100, value=50, step=1, label="Top-K"),
    # Nucleus (top-p) cutoff, converted with float().
    gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-P"),
]

demo = gr.ChatInterface(chatbot, additional_inputs=_extra_controls)
def _launch_app():
    """Start the Gradio server for the chat UI.

    ``share=True`` additionally asks Gradio for a temporary public share link
    alongside the local URL.
    """
    demo.launch(share=True)


if __name__ == "__main__":
    _launch_app()