# article_writer / ai_generate.py
# Author: minko186 — last commit fd96c74 ("Update ai_generate.py").
import torch
from openai import OpenAI
import os
from transformers import pipeline
# pipes = {
# 'GPT-Neo': pipeline("text-generation", model="EleutherAI/gpt-neo-2.7B"),
# 'Llama 3': pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B")
# }
# Hugging Face model ids for the locally-run text-generation models,
# keyed by the display name callers pass as ``model``.
_HF_MODELS = {
    "GPT-Neo": "EleutherAI/gpt-neo-2.7B",
    "Llama 3": "meta-llama/Meta-Llama-3-8B",
}

# OpenAI chat-completion model ids, keyed by the display name callers pass.
_OPENAI_MODELS = {
    "OpenAI GPT 3.5": "gpt-3.5-turbo",
    "OpenAI GPT 4": "gpt-4-turbo",
    "OpenAI GPT 4o": "gpt-4o",
}

# Lazily built text-generation pipelines, keyed by Hugging Face model id.
# Building a pipeline loads the model weights, so it is done once per model
# rather than once per call. NOTE(review): a loaded model stays resident for
# the life of the process.
_PIPELINES = {}


def _get_pipeline(model_id):
    """Return the cached ``text-generation`` pipeline for *model_id*, building it on first use."""
    pipe = _PIPELINES.get(model_id)
    if pipe is None:
        pipe = pipeline("text-generation", model=model_id)
        _PIPELINES[model_id] = pipe
    return pipe


def _openai_chat(text, model_id, api_key):
    """Send *text* as a single user message to OpenAI *model_id* and return the reply text."""
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model_id,
        messages=[{"role": "user", "content": text}],
        temperature=0.2,
        max_tokens=800,
        frequency_penalty=0.0,
    )
    # The ChatCompletion object is not subscriptable; replies live in .choices.
    return response.choices[0].message.content


def generate(text, model, api):
    """Generate a completion for *text* with the model named *model*.

    Args:
        text: Prompt text to send to the model.
        model: Display name of the model. It must be a key of ``_HF_MODELS``
            ("GPT-Neo", "Llama 3") or of ``_OPENAI_MODELS`` ("OpenAI GPT 3.5",
            "OpenAI GPT 4", "OpenAI GPT 4o").
        api: OpenAI API key, passed to ``OpenAI(api_key=...)``. It is ignored
            for the Hugging Face models.

    Returns:
        For Hugging Face models, the first element of the pipeline output.
        The transformers docs describe it as a dict with a ``generated_text``
        key; confirm against the installed version. For OpenAI models, the
        reply message content.

    Raises:
        ValueError: If *model* is not one of the supported names.
    """
    if model in _HF_MODELS:
        return _get_pipeline(_HF_MODELS[model])(text)[0]
    if model in _OPENAI_MODELS:
        return _openai_chat(text, _OPENAI_MODELS[model], api)
    supported = ", ".join(sorted([*_HF_MODELS, *_OPENAI_MODELS]))
    raise ValueError(f"Unsupported model {model!r}; expected one of: {supported}")