pszemraj committed on
Commit
58acd65
·
verified ·
1 Parent(s): 8120b87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  import gradio as gr
12
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
13
 
14
- model_id = "BEE-spoke-data/tFINE-900m-e16-d32-instruct"
15
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
16
  logging.info(f"Running on device:\t {torch_device}")
17
  logging.info(f"CPU threads:\t {torch.get_num_threads()}")
@@ -22,7 +22,7 @@ if torch_device == "cuda":
22
  model_id, load_in_8bit=True, device_map="auto"
23
  )
24
  else:
25
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
26
  try:
27
  model = torch.compile(model)
28
  except Exception as e:
@@ -165,4 +165,4 @@ with gr.Blocks() as demo:
165
  model_output,
166
  )
167
 
168
- demo.queue(max_size=32).launch()
 
11
  import gradio as gr
12
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
13
 
14
+ model_id = "BEE-spoke-data/tFINE-900m-e16-d32-instruct_2e"
15
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
16
  logging.info(f"Running on device:\t {torch_device}")
17
  logging.info(f"CPU threads:\t {torch.get_num_threads()}")
 
22
  model_id, load_in_8bit=True, device_map="auto"
23
  )
24
  else:
25
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
26
  try:
27
  model = torch.compile(model)
28
  except Exception as e:
 
165
  model_output,
166
  )
167
 
168
+ demo.queue(max_size=10).launch()