# article_writer/ai_generate.py
import os

from groq import Groq
from openai import OpenAI


def generate(text, model, api):
    """Generate a completion for `text` with the selected backend.

    "Llama 3" streams from the Groq API (keyed by the `groq_key`
    environment variable); the "OpenAI ..." options call the OpenAI
    API using `api` as the key.
    """
    if model == "Llama 3":
        client = Groq(
            api_key=os.environ.get("groq_key"),
        )
        completion = client.chat.completions.create(
            model="llama3-8b-8192",
            messages=[
                {
                    "role": "system",
                    "content": "Please follow the instruction and write about the given topic in approximately the given number of words",
                },
                {
                    "role": "user",
                    "content": text,
                },
            ],
            temperature=1,
            max_tokens=1024,
            top_p=1,
            stream=True,
            stop=None,
        )
        # Accumulate the streamed chunks into a single string and return it.
        response = ""
        for chunk in completion:
            response += chunk.choices[0].delta.content or ""
        return response
    elif model in ("OpenAI GPT 3.5", "OpenAI GPT 4", "OpenAI GPT 4o"):
        # Map the UI label to the corresponding OpenAI model id; all GPT
        # options share the same request parameters.
        openai_models = {
            "OpenAI GPT 3.5": "gpt-3.5-turbo",
            "OpenAI GPT 4": "gpt-4-turbo",
            "OpenAI GPT 4o": "gpt-4o",
        }
        client = OpenAI(
            api_key=api,
        )
        response = client.chat.completions.create(
            model=openai_models[model],
            messages=[{"role": "user", "content": text}],
            temperature=0.2,
            max_tokens=800,
            frequency_penalty=0.0,
        )
        return response.choices[0].message.content
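

# Minimal usage sketch, assuming an `OPENAI_API_KEY` environment variable
# holds the key for the GPT paths (the Llama 3 path reads `groq_key`
# instead); the model label, prompt, and env var name below are
# illustrative, not part of the module above.
if __name__ == "__main__":
    output = generate(
        "Write about solar energy in approximately 100 words.",
        model="OpenAI GPT 4o",
        api=os.environ.get("OPENAI_API_KEY", ""),
    )
    print(output)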