pszemraj committed on
Commit
ea9c426
·
verified ·
1 Parent(s): bf2a5db

update model

Browse files
Files changed (1) hide show
  1. app.py +3 -9
app.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  import gradio as gr
12
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
13
 
14
- model_id = "pszemraj/nanoT5-mid-2k-instruct"
15
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
16
  logging.info(f"Running on device:\t {torch_device}")
17
  logging.info(f"CPU threads:\t {torch.get_num_threads()}")
@@ -63,6 +63,7 @@ def run_generation(
63
  repetition_penalty=repetition_penalty,
64
  length_penalty=length_penalty,
65
  no_repeat_ngram_size=no_repeat_ngram_size,
 
66
  )
67
  t = Thread(target=model.generate, kwargs=generate_kwargs)
68
  t.start()
@@ -152,14 +153,7 @@ with gr.Blocks() as demo:
152
  interactive=True,
153
  label="Length Penalty",
154
  )
155
- # temperature = gr.Slider(
156
- # minimum=0.1,
157
- # maximum=5.0,
158
- # value=0.8,
159
- # step=0.1,
160
- # interactive=True,
161
- # label="Temperature",
162
- # )
163
  user_text.submit(
164
  run_generation,
165
  [user_text, top_p, temperature, top_k, max_new_tokens, repetition_penalty, length_penalty],
 
11
  import gradio as gr
12
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
13
 
14
+ model_id = "BEE-spoke-data/tFINE-900m-e16-d32-instruct"
15
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
16
  logging.info(f"Running on device:\t {torch_device}")
17
  logging.info(f"CPU threads:\t {torch.get_num_threads()}")
 
63
  repetition_penalty=repetition_penalty,
64
  length_penalty=length_penalty,
65
  no_repeat_ngram_size=no_repeat_ngram_size,
66
+ renormalize_logits=True,
67
  )
68
  t = Thread(target=model.generate, kwargs=generate_kwargs)
69
  t.start()
 
153
  interactive=True,
154
  label="Length Penalty",
155
  )
156
+
 
 
 
 
 
 
 
157
  user_text.submit(
158
  run_generation,
159
  [user_text, top_p, temperature, top_k, max_new_tokens, repetition_penalty, length_penalty],