minko186 commited on
Commit
f5e679e
·
verified ·
1 Parent(s): 65d485c

Update ai_generate.py

Browse files
Files changed (1) hide show
  1. ai_generate.py +27 -11
ai_generate.py CHANGED
@@ -2,19 +2,35 @@ import torch
2
  from openai import OpenAI
3
  import os
4
  from transformers import pipeline
5
-
6
- # pipes = {
7
- # 'GPT-Neo': pipeline("text-generation", model="EleutherAI/gpt-neo-2.7B"),
8
- # 'Llama 3': pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B")
9
- # }
10
 
11
  def generate(text, model, api):
12
- if model == "GPT-Neo":
13
- response = pipeline("text-generation", model="EleutherAI/gpt-neo-2.7B")(text)
14
- print(response)
15
- return response[0]
16
- elif model == "Llama 3":
17
- response = pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B")(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  print(response)
19
  return response[0]
20
  elif model == "OpenAI GPT 3.5":
 
2
  from openai import OpenAI
3
  import os
4
  from transformers import pipeline
5
+ from groq import Groq
 
 
 
 
6
 
7
  def generate(text, model, api):
8
+ if model == "Llama 3":
9
+ client = Groq(
10
+ api_key=os.environ.get("groq_key"),
11
+ )
12
+ completion = client.chat.completions.create(
13
+ model="llama3-8b-8192",
14
+ messages=[
15
+ {
16
+ "role": "user",
17
+ "content": text
18
+ },
19
+ {
20
+ "role": "assistant",
21
+ "content": "Please follow the instruction and write about the given topic in approximately the given number of words"
22
+ }
23
+ ],
24
+ temperature=1,
25
+ max_tokens=1024,
26
+ top_p=1,
27
+ stream=True,
28
+ stop=None,
29
+ )
30
+
31
+ for chunk in completion:
32
+ print(chunk.choices[0].delta.content or "", end="")
33
+
34
  print(response)
35
  return response[0]
36
  elif model == "OpenAI GPT 3.5":