bewchatbot / app.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json

app = FastAPI()

# Load the DialoGPT-large model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
# Load UTS courses data from a local JSON file
with open("uts_courses.json", "r") as file:
    courses_data = json.load(file)
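# Assumed shape of uts_courses.json (illustrative only; inferred from the
# lookups below, not confirmed by the source):
# {
#     "courses": {
#         "engineering": ["Software Engineering", "Data Engineering"],
#         "business": ["Accounting", "Marketing"]
#     }
# }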

class UserInput(BaseModel):
    user_input: str

def generate_response(user_input: str) -> str:
    normalized = user_input.lower()
    if normalized == "help":
        return "I can help you with information about UTS courses. Feel free to ask!"
    elif normalized == "exit":
        return "Goodbye!"
    elif normalized == "list courses":
        course_list = "\n".join(
            f"{category}: {', '.join(courses)}"
            for category, courses in courses_data["courses"].items()
        )
        return f"Here are the available courses:\n{course_list}"
    elif normalized in courses_data["courses"]:
        # Index with the same normalized key used for the membership check;
        # this assumes the category keys in the JSON file are lowercase.
        return f"The courses in {user_input} are: {', '.join(courses_data['courses'][normalized])}"
    else:
        # Fall back to DialoGPT for open-ended chat. Append the EOS token so
        # the model treats the input as a complete conversational turn.
        input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
        # Generate up to 100 tokens (prompt included)
        response_ids = model.generate(input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
        # Decode only the newly generated tokens, skipping the echoed prompt
        response = tokenizer.decode(response_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
        return response

@app.post("/")
def chat(user_input: UserInput):
    response = generate_response(user_input.user_input)
    return {"response": response}