Spaces:
Runtime error
Runtime error
import torch | |
from openai import OpenAI | |
import os | |
from transformers import pipeline | |
from groq import Groq | |
# Shared Groq client used by generate_groq(). The API key is read from the
# "groq_key" environment variable; if that variable is unset, None is passed.
groq_client = Groq(api_key=os.getenv("groq_key"))
def generate_groq(text, model):
    """Generate a reply to *text* with a Groq-hosted chat model.

    The request is streamed and the streamed pieces are concatenated, so the
    caller receives the complete response as a single string.

    Args:
        text: The user prompt.
        model: Groq model identifier (for example ``"llama3-70b-8192"``).

    Returns:
        The full response text; an empty string if no content was streamed.
    """
    completion = groq_client.chat.completions.create(
        model=model,
        messages=[
            # The standing instruction goes first, as a system message. Sent
            # as a trailing "assistant" message it is read as the beginning
            # of the model's own reply (a prefill), not as an instruction.
            {
                "role": "system",
                "content": "Please follow the instruction and write about the given topic in approximately the given number of words",
            },
            {"role": "user", "content": text},
        ],
        temperature=1,
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )
    # Every chunk is consumed (the first one is no longer discarded).
    # Chunks without choices are skipped defensively, and a delta whose
    # content is None contributes an empty string.
    return "".join(
        chunk.choices[0].delta.content or ""
        for chunk in completion
        if chunk.choices
    )
def generate_openai(text, model, openai_client):
    """Return an OpenAI chat model's reply to a single user prompt.

    Args:
        text: The user prompt, sent as the only message of the conversation.
        model: Model identifier passed through to the API.
        openai_client: Client object exposing ``chat.completions.create``.

    Returns:
        The message content of the first choice in the response.
    """
    request = {
        "model": model,
        "messages": [{"role": "user", "content": text}],
        "temperature": 0.2,
        "max_tokens": 800,
        "frequency_penalty": 0.0,
    }
    completion = openai_client.chat.completions.create(**request)
    first_choice = completion.choices[0]
    return first_choice.message.content
# Model names accepted by generate() that are served through Groq,
# mapped to the Groq model identifier.
_GROQ_MODELS = {
    "Llama 3": "llama3-70b-8192",
    "Groq": "llama3-groq-70b-8192-tool-use-preview",
    "Mistral": "mixtral-8x7b-32768",
    "Gemma": "gemma2-9b-it",
}

# Model names accepted by generate() that are served through OpenAI,
# mapped to the OpenAI model identifier.
_OPENAI_MODELS = {
    "OpenAI GPT 3.5": "gpt-3.5-turbo",
    "OpenAI GPT 4": "gpt-4-turbo",
    "OpenAI GPT 4o": "gpt-4o",
}


def generate(text, model, api):
    """Generate text with the backend selected by *model*.

    Args:
        text: The prompt to send.
        model: One of the names in ``_GROQ_MODELS`` or ``_OPENAI_MODELS``.
        api: OpenAI API key; only used for the OpenAI models.

    Returns:
        The generated text. For OpenAI models, the string
        ``"Please add a valid API key"`` is returned if creating the client
        or the request fails. ``None`` is returned for an unrecognized model
        name.
    """
    groq_model = _GROQ_MODELS.get(model)
    if groq_model is not None:
        return generate_groq(text, groq_model)

    openai_model = _OPENAI_MODELS.get(model)
    if openai_model is not None:
        try:
            openai_client = OpenAI(api_key=api)
            return generate_openai(text, openai_model, openai_client)
        except Exception:
            # Best-effort: any failure is reported to the caller as a message
            # instead of raising. `Exception` (not a bare `except`) so that
            # KeyboardInterrupt / SystemExit still propagate.
            return "Please add a valid API key"

    # Unknown model name: no backend to call.
    return None