vilarin commited on
Commit
b42e029
·
verified ·
1 Parent(s): 022d2fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -9,7 +9,8 @@ from threading import Thread
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
  MODEL_ID = "CohereForAI/aya-23-8B"
12
- MODEL_NAME = MODEL_ID.split("/")[-1]
 
13
 
14
  TITLE = "<h1><center>Aya-23-Chatbox</center></h1>"
15
 
@@ -34,26 +35,27 @@ USE_FLASH_ATTENTION = False
34
  GRAD_ACC_STEPS = 16
35
 
36
  quantization_config = None
 
37
  if QUANTIZE_4BIT:
38
- quantization_config = BitsAndBytesConfig(
39
- load_in_4bit=True,
40
- bnb_4bit_quant_type="nf4",
41
- bnb_4bit_use_double_quant=True,
42
- bnb_4bit_compute_dtype=torch.bfloat16,
43
- )
44
 
45
  attn_implementation = None
46
  if USE_FLASH_ATTENTION:
47
- attn_implementation="flash_attention_2"
48
 
49
  model = AutoModelForCausalLM.from_pretrained(
50
- MODEL_ID,
51
  quantization_config=quantization_config,
52
  attn_implementation=attn_implementation,
53
  torch_dtype=torch.bfloat16,
54
  device_map="auto",
55
  )
56
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
57
 
58
  @spaces.GPU
59
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
  MODEL_ID = "CohereForAI/aya-23-8B"
12
+ MODEL_ID2 = "CohereForAI/aya-23-35B"
13
+ MODEL_NAME = MODEL_ID2.split("/")[-1]
14
 
15
  TITLE = "<h1><center>Aya-23-Chatbox</center></h1>"
16
 
 
35
  GRAD_ACC_STEPS = 16
36
 
37
  quantization_config = None
38
+
39
  if QUANTIZE_4BIT:
40
+ quantization_config = BitsAndBytesConfig(
41
+ load_in_4bit=True,
42
+ bnb_4bit_quant_type="nf4",
43
+ bnb_4bit_use_double_quant=True,
44
+ bnb_4bit_compute_dtype=torch.bfloat16,
45
+ )
46
 
47
  attn_implementation = None
48
  if USE_FLASH_ATTENTION:
49
+ attn_implementation="flash_attention_2"
50
 
51
  model = AutoModelForCausalLM.from_pretrained(
52
+ MODEL_ID2,
53
  quantization_config=quantization_config,
54
  attn_implementation=attn_implementation,
55
  torch_dtype=torch.bfloat16,
56
  device_map="auto",
57
  )
58
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID2)
59
 
60
  @spaces.GPU
61
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):