File size: 2,287 Bytes
17d12d8
29e6656
7edc5be
e1b0f65
f5e679e
e1b0f65
5f853f6
 
 
 
 
d62b586
5f853f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237e24d
36c549c
 
 
 
 
 
 
 
c80b6f5
36c549c
2d067ac
f5e679e
5f853f6
 
 
 
 
 
 
088ef38
7b33051
 
 
 
 
088ef38
7b33051
 
 
 
 
088ef38
7b33051
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from openai import OpenAI
import os
from transformers import pipeline
from groq import Groq

# Module-level Groq client shared by every generate_groq() call.
# The API key is read from the "groq_key" environment variable; if the
# variable is unset, os.environ.get returns None and api_key is None.
groq_client = Groq(
     api_key=os.environ.get("groq_key"),
)

def generate_groq(text, model):
    """Stream a chat completion from Groq and return the concatenated text.

    Parameters
    ----------
    text : str
        The user prompt (topic plus any length instruction).
    model : str
        Groq model identifier, e.g. "llama3-70b-8192".

    Returns
    -------
    str
        The full response assembled from the streamed chunks.
    """
    completion = groq_client.chat.completions.create(
        model=model,
        messages=[
            # The steering instruction belongs in the "system" role. The
            # original sent it as an "assistant" turn *after* the user
            # message, which pre-fills the model's reply instead of
            # instructing it.
            {
                "role": "system",
                "content": "Please follow the instruction and write about the given topic in approximately the given number of words"
            },
            {
                "role": "user",
                "content": text
            },
        ],
        temperature=1,
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )
    # Accumulate every streamed delta. The original skipped the first
    # chunk (`if i != 0`), which could drop leading content; empty or
    # None deltas are already handled by the `or ""` fallback.
    response = ""
    for chunk in completion:
        response += chunk.choices[0].delta.content or ""
    return response

def generate_openai(text, model, openai_client):
    """Request a single (non-streaming) chat completion from OpenAI.

    Sends *text* as one user message to *model* via the supplied
    *openai_client* and returns the assistant's reply text.
    """
    request_messages = [{"role": "user", "content": text}]
    completion = openai_client.chat.completions.create(
        model=model,
        messages=request_messages,
        temperature=0.2,
        max_tokens=800,
        frequency_penalty=0.0,
    )
    return completion.choices[0].message.content

def generate(text, model, api):
    """Dispatch *text* to the selected model and return the generated reply.

    Parameters
    ----------
    text : str
        Prompt to send to the model.
    model : str
        UI-facing model name: "Llama 3", "Groq", "Mistral", "Gemma",
        "OpenAI GPT 3.5", "OpenAI GPT 4", or "OpenAI GPT 4o".
    api : str
        OpenAI API key; used only by the OpenAI models.

    Returns
    -------
    str or None
        The generated text, an error message when the OpenAI client
        cannot be created/used, or None for an unrecognized model name
        (matching the original implicit fallthrough).
    """
    # Dispatch tables replace the copy-pasted if/elif chain.
    groq_models = {
        "Llama 3": "llama3-70b-8192",
        "Groq": "llama3-groq-70b-8192-tool-use-preview",
        "Mistral": "mixtral-8x7b-32768",
        "Gemma": "gemma2-9b-it",
    }
    openai_models = {
        "OpenAI GPT 3.5": "gpt-3.5-turbo",
        "OpenAI GPT 4": "gpt-4-turbo",
        "OpenAI GPT 4o": "gpt-4o",
    }
    if model in groq_models:
        return generate_groq(text, groq_models[model])
    if model in openai_models:
        try:
            openai_client = OpenAI(api_key=api)
            return generate_openai(text, openai_models[model], openai_client)
        except Exception:
            # Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Any client/API failure is
            # still reported as a bad key, as before.
            return "Please add a valid API key"
    return None  # unknown model name — explicit instead of fallthrough