File size: 2,200 Bytes
17d12d8
29e6656
7edc5be
e1b0f65
f5e679e
e1b0f65
5f853f6
cf245ed
5f853f6
 
cf245ed
5f853f6
d62b586
cf245ed
5f853f6
cf245ed
5f853f6
 
cf245ed
 
5f853f6
 
 
 
 
 
 
 
 
 
 
 
 
cf245ed
237e24d
cf245ed
36c549c
cf245ed
36c549c
c80b6f5
36c549c
cf245ed
2d067ac
f5e679e
5f853f6
 
 
 
cf245ed
5f853f6
cf245ed
088ef38
7b33051
 
 
 
 
088ef38
7b33051
 
 
 
 
088ef38
7b33051
 
 
 
cf245ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os

import torch
from groq import Groq
from openai import AuthenticationError, OpenAI, OpenAIError
from transformers import pipeline

# Shared Groq client used by generate_groq(). The key comes from the
# "groq_key" environment variable; os.environ.get yields None when it is unset.
groq_client = Groq(api_key=os.environ.get("groq_key"))


def generate_groq(text, model):
    """Generate a completion for *text* with a Groq-hosted *model*.

    The request is streamed, and the text deltas of every chunk are
    concatenated into a single string.

    Args:
        text: The user prompt.
        model: Groq model identifier, e.g. ``"llama3-70b-8192"``.

    Returns:
        The generated text ("" if the stream carries no content).
    """
    completion = groq_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": text},
            # NOTE(review): this instruction is sent as an *assistant* turn
            # after the user message. A leading "system" message is the usual
            # way to steer the model -- confirm this ordering is intended.
            {
                "role": "assistant",
                "content": "Please follow the instruction and write about the given topic in approximately the given number of words",
            },
        ],
        temperature=1,
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )
    parts = []
    # Consume every chunk, including the first. Role-only or final chunks
    # have delta.content == None and contribute "". Chunks without any
    # choice are skipped defensively.
    for chunk in completion:
        if chunk.choices:
            parts.append(chunk.choices[0].delta.content or "")
    return "".join(parts)


def generate_openai(text, model, openai_client):
    """Send *text* as a single user message to the OpenAI chat *model*.

    Args:
        text: The prompt to send.
        model: OpenAI model identifier, e.g. ``"gpt-4o"``.
        openai_client: An ``OpenAI`` client (or an object exposing the same
            ``chat.completions.create`` call).

    Returns:
        The message content of the first returned choice.
    """
    completion = openai_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": text}],
        temperature=0.2,
        max_tokens=800,
        frequency_penalty=0.0,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


# Display name -> Groq model id. These use the shared module-level Groq
# client and need no user-supplied key.
GROQ_MODELS = {
    "Llama 3": "llama3-70b-8192",
    "Groq": "llama3-groq-70b-8192-tool-use-preview",
    "Mistral": "mixtral-8x7b-32768",
    "Gemma": "gemma2-9b-it",
}

# Display name -> OpenAI model id. These are called with the caller's API key.
OPENAI_MODELS = {
    "OpenAI GPT 3.5": "gpt-3.5-turbo",
    "OpenAI GPT 4": "gpt-4-turbo",
    "OpenAI GPT 4o": "gpt-4o",
}


def _generate_with_openai_key(text, model_id, api):
    """Run *text* through OpenAI *model_id* using the API key *api*.

    Returns the generated text, or a user-facing error string when the key
    is missing or rejected, or when the OpenAI request fails.
    """
    try:
        openai_client = OpenAI(api_key=api)
    except OpenAIError:
        # Per the openai SDK, the constructor raises OpenAIError when no key
        # is passed and OPENAI_API_KEY is unset.
        return "Please add a valid API key"
    try:
        return generate_openai(text, model_id, openai_client)
    except AuthenticationError:
        return "Please add a valid API key"
    except OpenAIError as err:
        # Rate limits, connection problems, etc.: report the real cause
        # instead of blaming the key. Non-SDK exceptions (bugs) propagate.
        return f"OpenAI request failed: {err}"


def generate(text, model, api):
    """Generate text for *text* with the backend selected by display name.

    Args:
        text: The prompt to send.
        model: A display name from ``GROQ_MODELS`` or ``OPENAI_MODELS``.
        api: OpenAI API key. It is only used for OpenAI models.

    Returns:
        The generated text. For OpenAI models, a user-facing error string is
        returned instead when the key is invalid or the request fails.
        Returns None when *model* is not a known display name.
    """
    if model in GROQ_MODELS:
        return generate_groq(text, GROQ_MODELS[model])
    if model in OPENAI_MODELS:
        return _generate_with_openai_key(text, OPENAI_MODELS[model], api)
    # Unknown display name: there is no backend to dispatch to.
    return None