davanstrien HF staff committed on
Commit
37a83d8
·
1 Parent(s): 3ef94a5
Files changed (1) hide show
  1. app.py +46 -16
app.py CHANGED
@@ -16,22 +16,36 @@ from theme import TufteInspired
16
  # Ensure you're logged in to Hugging Face
17
  login(os.getenv("HF_TOKEN"))
18
 
 
 
 
 
 
 
19
 
20
 
21
- client = OpenAI(
22
- base_url="https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct/v1",
23
- api_key=get_token(),
24
- )
 
 
 
 
 
 
25
 
26
  # Set up dataset storage
27
  dataset_folder = Path("dataset")
28
  dataset_folder.mkdir(exist_ok=True)
29
 
 
30
  # Function to get the latest dataset file
31
  def get_latest_dataset_file():
32
  files = list(dataset_folder.glob("data_*.jsonl"))
33
  return max(files, key=os.path.getctime) if files else None
34
 
 
35
  # Check for existing dataset and create or append to it
36
  if latest_file := get_latest_dataset_file():
37
  dataset_file = latest_file
@@ -53,18 +67,23 @@ scheduler = CommitScheduler(
53
  # Global dictionary to store votes
54
  votes = {}
55
 
 
56
  def generate_prompt():
57
  if random.choice([True, False]):
58
  return detailed_genre_description_prompt()
59
  else:
60
  return basic_prompt()
61
 
 
62
  def get_and_store_prompt():
63
  prompt = generate_prompt()
64
  print(prompt) # Keep this for debugging
65
  return prompt
66
 
 
67
  def generate_blurb(prompt):
 
 
68
  max_tokens = random.randint(100, 1000)
69
  chat_completion = client.chat.completions.create(
70
  model="tgi",
@@ -77,12 +96,17 @@ def generate_blurb(prompt):
77
  full_text = ""
78
  for message in chat_completion:
79
  full_text += message.choices[0].delta.content
80
- yield full_text
 
 
81
 
82
  def generate_vote_id(user_id, blurb):
83
  return hashlib.md5(f"{user_id}:{blurb}".encode()).hexdigest()
84
 
85
- def log_blurb_and_vote(prompt, blurb, vote, user_info: gr.OAuthProfile | None, *args):
 
 
 
86
  user_id = user_info.username if user_info is not None else str(uuid.uuid4())
87
  vote_id = generate_vote_id(user_id, blurb)
88
 
@@ -98,14 +122,16 @@ def log_blurb_and_vote(prompt, blurb, vote, user_info: gr.OAuthProfile | None, *
98
  "blurb": blurb,
99
  "vote": vote,
100
  "user_id": user_id,
 
101
  }
102
  with scheduler.lock:
103
  with dataset_file.open("a") as f:
104
  f.write(json.dumps(log_entry) + "\n")
105
-
106
  gr.Info("Thank you for voting! Your feedback will be synced to the dataset.")
107
  return f"Logged: {vote} by user {user_id}", gr.Row.update(visible=False)
108
 
 
109
  # Create custom theme
110
  tufte_theme = TufteInspired()
111
 
@@ -125,6 +151,7 @@ with gr.Blocks(theme=tufte_theme) as demo:
125
  prompt_state = gr.State()
126
  blurb_output = gr.Markdown(label="Book blurb")
127
  user_state = gr.State()
 
128
 
129
  with gr.Row(visible=False) as voting_row:
130
  upvote_btn = gr.Button("👍 would read")
@@ -133,20 +160,21 @@ with gr.Blocks(theme=tufte_theme) as demo:
133
  vote_output = gr.Textbox(label="Vote Status", interactive=False, visible=True)
134
 
135
  def generate_and_show(prompt, user_info):
136
- # Optionally clear votes for the previous blurb if needed
137
- # global votes
138
- # votes = {k: v for k, v in votes.items() if not k.endswith(hash(previous_blurb))}
139
- return "Generating...", gr.Row.update(visible=False), user_info
140
 
141
- def show_voting_buttons(blurb):
142
- return blurb, gr.Row.update(visible=True)
143
 
144
  generate_btn.click(get_and_store_prompt, outputs=prompt_state).then(
145
  generate_and_show,
146
  inputs=[prompt_state, login_btn],
147
- outputs=[blurb_output, voting_row, user_state],
148
- ).then(generate_blurb, inputs=prompt_state, outputs=blurb_output).then(
149
- show_voting_buttons, inputs=blurb_output, outputs=[blurb_output, voting_row]
 
 
 
 
150
  )
151
 
152
  upvote_btn.click(
@@ -156,6 +184,7 @@ with gr.Blocks(theme=tufte_theme) as demo:
156
  blurb_output,
157
  gr.Textbox(value="upvote", visible=False),
158
  user_state,
 
159
  ],
160
  outputs=[vote_output, voting_row],
161
  )
@@ -166,6 +195,7 @@ with gr.Blocks(theme=tufte_theme) as demo:
166
  blurb_output,
167
  gr.Textbox(value="downvote", visible=False),
168
  user_state,
 
169
  ],
170
  outputs=[vote_output, voting_row],
171
  )
 
16
  # Ensure you're logged in to Hugging Face
17
  login(os.getenv("HF_TOKEN"))
18
 
19
+ # Define available models
20
+ MODELS = [
21
+ "meta-llama/Meta-Llama-3-70B-Instruct",
22
+ "mistralai/Mixtral-8x7B-Instruct-v0.1",
23
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
24
+ ]
25
 
26
 
27
+ def get_random_model():
28
+ return random.choice(MODELS)
29
+
30
+
31
+ def create_client(model_id):
32
+ return OpenAI(
33
+ base_url=f"https://api-inference.huggingface.co/models/{model_id}/v1",
34
+ api_key=get_token(),
35
+ )
36
+
37
 
38
  # Set up dataset storage
39
  dataset_folder = Path("dataset")
40
  dataset_folder.mkdir(exist_ok=True)
41
 
42
+
43
  # Function to get the latest dataset file
44
  def get_latest_dataset_file():
45
  files = list(dataset_folder.glob("data_*.jsonl"))
46
  return max(files, key=os.path.getctime) if files else None
47
 
48
+
49
  # Check for existing dataset and create or append to it
50
  if latest_file := get_latest_dataset_file():
51
  dataset_file = latest_file
 
67
  # Global dictionary to store votes
68
  votes = {}
69
 
70
+
71
  def generate_prompt():
72
  if random.choice([True, False]):
73
  return detailed_genre_description_prompt()
74
  else:
75
  return basic_prompt()
76
 
77
+
78
  def get_and_store_prompt():
79
  prompt = generate_prompt()
80
  print(prompt) # Keep this for debugging
81
  return prompt
82
 
83
+
84
  def generate_blurb(prompt):
85
+ model_id = get_random_model()
86
+ client = create_client(model_id)
87
  max_tokens = random.randint(100, 1000)
88
  chat_completion = client.chat.completions.create(
89
  model="tgi",
 
96
  full_text = ""
97
  for message in chat_completion:
98
  full_text += message.choices[0].delta.content
99
+ yield full_text, model_id
100
+ return full_text, model_id # Return final result with model_id
101
+
102
 
103
  def generate_vote_id(user_id, blurb):
104
  return hashlib.md5(f"{user_id}:{blurb}".encode()).hexdigest()
105
 
106
+
107
+ def log_blurb_and_vote(
108
+ prompt, blurb, vote, user_info: gr.OAuthProfile | None, model_id, *args
109
+ ):
110
  user_id = user_info.username if user_info is not None else str(uuid.uuid4())
111
  vote_id = generate_vote_id(user_id, blurb)
112
 
 
122
  "blurb": blurb,
123
  "vote": vote,
124
  "user_id": user_id,
125
+ "model_id": model_id,
126
  }
127
  with scheduler.lock:
128
  with dataset_file.open("a") as f:
129
  f.write(json.dumps(log_entry) + "\n")
130
+
131
  gr.Info("Thank you for voting! Your feedback will be synced to the dataset.")
132
  return f"Logged: {vote} by user {user_id}", gr.Row.update(visible=False)
133
 
134
+
135
  # Create custom theme
136
  tufte_theme = TufteInspired()
137
 
 
151
  prompt_state = gr.State()
152
  blurb_output = gr.Markdown(label="Book blurb")
153
  user_state = gr.State()
154
+ model_state = gr.State()
155
 
156
  with gr.Row(visible=False) as voting_row:
157
  upvote_btn = gr.Button("👍 would read")
 
160
  vote_output = gr.Textbox(label="Vote Status", interactive=False, visible=True)
161
 
162
  def generate_and_show(prompt, user_info):
163
+ return "Generating...", gr.Row.update(visible=False), user_info, None
 
 
 
164
 
165
+ def show_voting_buttons(blurb, model_id):
166
+ return blurb, gr.Row.update(visible=True), model_id
167
 
168
  generate_btn.click(get_and_store_prompt, outputs=prompt_state).then(
169
  generate_and_show,
170
  inputs=[prompt_state, login_btn],
171
+ outputs=[blurb_output, voting_row, user_state, model_state],
172
+ ).then(
173
+ generate_blurb, inputs=prompt_state, outputs=[blurb_output, model_state]
174
+ ).then(
175
+ show_voting_buttons,
176
+ inputs=[blurb_output, model_state],
177
+ outputs=[blurb_output, voting_row, model_state],
178
  )
179
 
180
  upvote_btn.click(
 
184
  blurb_output,
185
  gr.Textbox(value="upvote", visible=False),
186
  user_state,
187
+ model_state,
188
  ],
189
  outputs=[vote_output, voting_row],
190
  )
 
195
  blurb_output,
196
  gr.Textbox(value="downvote", visible=False),
197
  user_state,
198
+ model_state,
199
  ],
200
  outputs=[vote_output, voting_row],
201
  )