groloch committed
Commit c48bfa8 · 1 Parent(s): f276860

Add workaround for gated models

Files changed (2)
  1. README.md +0 -5
  2. app.py +27 -3
README.md CHANGED
@@ -9,11 +9,6 @@ app_file: app.py
 pinned: false
 license: apache-2.0
 short_description: Prompt enhancing models interface
-
-hf_oauth: true
-hf_oauth_scopes:
-- read-repos
-- manage-repos
 ---
 
 A playground to test and compare several prompt enhancing models.
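For context: the deleted `hf_oauth` front matter had enabled Hugging Face OAuth for this Space, in which case a Gradio handler can receive the signed-in user's token automatically. A minimal sketch of that approach, the one this commit moves away from (the handler and its wiring are illustrative, not taken from this repo):

```python
import gradio as gr

# With `hf_oauth: true` in the Space README, Gradio injects the signed-in
# user's token into any parameter annotated as gr.OAuthToken (None when
# the visitor is not logged in).
def generate(prompt_to_enhance: str, oauth_token: gr.OAuthToken | None = None):
    if oauth_token is None:
        raise gr.Error('Please log in with your Hugging Face account')
    token = oauth_token.token  # usable for downloading gated repos
    ...
```

The commit instead collects a token through a plain text input and logs in with `huggingface_hub.login()`, as the app.py diff below shows.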
app.py CHANGED
@@ -1,6 +1,7 @@
 import torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from huggingface_hub import login
 
 
 choices_base_models = {
@@ -17,11 +18,18 @@ choices_gen_token = {
     'groloch/Ministral-3b-instruct-PromptEnhancing': 'ministral/Ministral-3b-instruct'
 }
 
+gated_models = [
+    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing',
+    'groloch/gemma-2-2b-it-PromptEnhancing'
+]
+
 previous_choice = ''
 
 model = None
 tokenizer = None
 
+logged_in = False
+
 
 def load_model(adapter_repo_id: str):
     global model, tokenizer
@@ -37,7 +45,8 @@ def generate(prompt_to_enhance: str,
              max_tokens: float,
              temperature: float,
              top_p: float,
-             repetition_penalty: float
+             repetition_penalty: float,
+             access_token: str
              ):
     if prompt_to_enhance is None or prompt_to_enhance == '':
         raise gr.Error('Please enter a prompt')
@@ -47,6 +56,15 @@ def generate(prompt_to_enhance: str,
         previous_choice = choice
         load_model(choice)
 
+    if choice in gated_models and access_token == '':
+        raise gr.Error(f'Please enter your access token (in Additional inputs) if you\'re using one of the following \
+                       models: {", ".join(gated_models)}. Make sure you have access to those models.')
+
+    global logged_in
+    if not logged_in and choice in gated_models:
+        login(access_token)
+        logged_in = True
+
     chat = [
         {'role' : 'user', 'content': prompt_to_enhance}
     ]
@@ -124,17 +142,23 @@ input_repetition_penalty = gr.Number(
     maximum=5.0,
     step=0.1
 )
+input_access_token = gr.Text(
+    label='Access token for gated models',
+    value=''
+)
 
 demo = gr.Interface(
     generate,
     title='Prompt Enhancing Playground',
     description='This space is a tool to compare the different prompt enhancing model I have finetuned. \
-                Feel free to experiment as you want !',
+                Feel free to experiment as you want ! \n\
+                If you want to use this locally, you can download the gpu version (see in files)',
     inputs=[input_prompt, model_choice],
     additional_inputs=[input_max_tokens,
                        input_temperature,
                        input_top_p,
-                       input_repetition_penalty
+                       input_repetition_penalty,
+                       input_access_token
                        ],
     outputs=['text']
)
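The workaround hinges on `huggingface_hub.login()` authenticating the whole process: after one successful call, later `from_pretrained()` downloads of gated repositories succeed without passing the token explicitly, which is what `load_model()` relies on. A minimal sketch of that mechanism (the base model id is a presumed gated dependency of the adapters above; the token is a placeholder):

```python
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

login(token='hf_...')  # placeholder; a user token with access to the gated repo

# Once logged in, from_pretrained can resolve gated repositories without an
# explicit token argument.
tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it')
```

One design consequence worth noting: `login()` and the `logged_in` flag are process-wide, so in a shared Space the first visitor's token serves every later request until the Space restarts.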