minko186 committed
Commit 5f853f6 · verified · 1 Parent(s): 0fc1c20

Update ai_generate.py

Files changed (1)
  1. ai_generate.py +36 -27
ai_generate.py CHANGED
@@ -4,35 +4,44 @@ import os
 from transformers import pipeline
 from groq import Groq
 
+groq_client = Groq(
+    api_key=os.environ.get("groq_key"),
+)
+
+def generate_groq(text, model):
+    completion = groq_client.chat.completions.create(
+        model=model,
+        messages=[
+            {
+                "role": "user",
+                "content": text
+            },
+            {
+                "role": "assistant",
+                "content": "Please follow the instruction and write about the given topic in approximately the given number of words"
+            }
+        ],
+        temperature=1,
+        max_tokens=1024,
+        top_p=1,
+        stream=True,
+        stop=None,
+    )
+    response = ""
+    for i, chunk in enumerate(completion):
+        if i != 0:
+            response += chunk.choices[0].delta.content or ""
+    return response
+
 def generate(text, model, api):
     if model == "Llama 3":
-        client = Groq(
-            api_key=os.environ.get("groq_key"),
-        )
-        completion = client.chat.completions.create(
-            model="llama3-8b-8192",
-            messages=[
-                {
-                    "role": "user",
-                    "content": text
-                },
-                {
-                    "role": "assistant",
-                    "content": "Please follow the instruction and write about the given topic in approximately the given number of words"
-                }
-            ],
-            temperature=1,
-            max_tokens=1024,
-            top_p=1,
-            stream=True,
-            stop=None,
-        )
-
-        response = ""
-        for i, chunk in enumerate(completion):
-            if i != 0:
-                response += chunk.choices[0].delta.content or ""
-        return response
+        return generate_groq(text, "llama3-70b-8192")
+    elif model == "Groq":
+        return generate_groq(text, "llama3-groq-70b-8192-tool-use-preview")
+    elif model == "Mistral":
+        return generate_groq(text, "mixtral-8x7b-32768")
+    elif model == "Gemma":
+        return generate_groq(text, "gemma2-9b-it")
     elif model == "OpenAI GPT 3.5":
         client = OpenAI(
             api_key=api,
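
For reference, a minimal sketch of how the refactored dispatcher might be exercised from outside the module; the `ai_generate` import path, the placeholder key value, and the example prompt are assumptions, while the model labels come from the dispatch table in the diff above.

import os

# The Groq client is created at import time from the "groq_key" environment
# variable, so the key must be available before the module is imported.
os.environ.setdefault("groq_key", "gsk_...")  # placeholder value, not a real key

from ai_generate import generate  # assumed import path for this file

# Model labels match the new dispatch table; the `api` argument is only
# consumed by the OpenAI branch, so it can be left as None here.
prompt = "Write about the benefits of unit testing in approximately 150 words."
print(generate(prompt, "Mistral", None))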