vilarin commited on
Commit
bd34f0b
·
verified ·
1 Parent(s): 9784048

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -39
app.py CHANGED
@@ -2,64 +2,49 @@ import torch
2
  from PIL import Image
3
  import gradio as gr
4
  import spaces
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
6
  import os
7
  from threading import Thread
8
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
- MODEL_ID = "CohereForAI/aya-23-8B"
12
- MODEL_ID2 = "CohereForAI/aya-23-35B"
13
  MODELS = os.environ.get("MODELS")
14
  MODEL_NAME = MODELS.split("/")[-1]
15
 
16
- TITLE = "<h1><center>Aya-23-Chatbox</center></h1>"
17
 
18
- DESCRIPTION = f'<h3><center>MODEL: <a href="https://hf.co/{MODELS}">{MODEL_NAME}</a></center></h3>'
 
 
 
 
 
 
 
 
19
 
20
  CSS = """
21
  .duplicate-button {
22
- margin: auto !important;
23
- color: white !important;
24
- background: black !important;
25
- border-radius: 100vh !important;
 
 
 
26
  }
27
  """
28
 
29
-
30
- #QUANTIZE
31
- QUANTIZE_4BIT = True
32
- USE_GRAD_CHECKPOINTING = True
33
- TRAIN_BATCH_SIZE = 2
34
- TRAIN_MAX_SEQ_LENGTH = 512
35
- USE_FLASH_ATTENTION = False
36
- GRAD_ACC_STEPS = 16
37
-
38
- quantization_config = None
39
-
40
- if QUANTIZE_4BIT:
41
- quantization_config = BitsAndBytesConfig(
42
- load_in_4bit=True,
43
- bnb_4bit_quant_type="nf4",
44
- bnb_4bit_use_double_quant=True,
45
- bnb_4bit_compute_dtype=torch.bfloat16,
46
- )
47
-
48
- attn_implementation = None
49
- if USE_FLASH_ATTENTION:
50
- attn_implementation="flash_attention_2"
51
-
52
  model = AutoModelForCausalLM.from_pretrained(
53
  MODELS,
54
- quantization_config=quantization_config,
55
- attn_implementation=attn_implementation,
56
- torch_dtype=torch.bfloat16,
57
  device_map="auto",
58
  )
59
  tokenizer = AutoTokenizer.from_pretrained(MODELS)
60
 
61
  @spaces.GPU
62
- def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
63
  print(f'message is - {message}')
64
  print(f'history is - {history}')
65
  conversation = []
@@ -69,16 +54,21 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
69
 
70
  print(f"Conversation is -\n{conversation}")
71
 
72
- input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
 
73
 
74
- streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": True, "skip_prompt": True, 'clean_up_tokenization_spaces':False,})
75
 
76
  generate_kwargs = dict(
77
- input_ids=input_ids,
78
  streamer=streamer,
 
 
 
79
  max_new_tokens=max_new_tokens,
80
  do_sample=True,
81
  temperature=temperature,
 
82
  )
83
 
84
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -119,6 +109,30 @@ with gr.Blocks(css=CSS) as demo:
119
  label="Max new tokens",
120
  render=False,
121
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  ],
123
  examples=[
124
  ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
 
2
  from PIL import Image
3
  import gradio as gr
4
  import spaces
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  import os
7
  from threading import Thread
8
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
+ MODEL_ID = "Qwen/Qwen1.5-7B-Chat"
 
12
  MODELS = os.environ.get("MODELS")
13
  MODEL_NAME = MODELS.split("/")[-1]
14
 
15
+ TITLE = "<h1><center>Qwen2-Chatbox</center></h1>"
16
 
17
+ DESCRIPTION = f"""
18
+ <h3>MODEL: <a href="https://hf.co/{MODELS}">{MODEL_NAME}</a></h3>
19
+ <center>
20
+ <p>Qwen is the large language model built by Alibaba Cloud.
21
+ <br>
22
+ Feel free to test without log.
23
+ </p>
24
+ </center>
25
+ """
26
 
27
  CSS = """
28
  .duplicate-button {
29
+ margin: auto !important;
30
+ color: white !important;
31
+ background: black !important;
32
+ border-radius: 100vh !important;
33
+ }
34
+ h3 {
35
+ text-align: center;
36
  }
37
  """
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  model = AutoModelForCausalLM.from_pretrained(
40
  MODELS,
41
+ torch_dtype=torch.float16,
 
 
42
  device_map="auto",
43
  )
44
  tokenizer = AutoTokenizer.from_pretrained(MODELS)
45
 
46
  @spaces.GPU
47
+ def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
48
  print(f'message is - {message}')
49
  print(f'history is - {history}')
50
  conversation = []
 
54
 
55
  print(f"Conversation is -\n{conversation}")
56
 
57
+ input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
58
+ inputs = tokenizer(input_ids, return_tensors="pt").to(0)
59
 
60
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
61
 
62
  generate_kwargs = dict(
63
+ inputs,
64
  streamer=streamer,
65
+ top_k=top_k,
66
+ top_p=top_p,
67
+ repetition_penalty=penalty,
68
  max_new_tokens=max_new_tokens,
69
  do_sample=True,
70
  temperature=temperature,
71
+ eos_token_id = [151645, 151643],
72
  )
73
 
74
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
 
109
  label="Max new tokens",
110
  render=False,
111
  ),
112
+ gr.Slider(
113
+ minimum=0.0,
114
+ maximum=1.0,
115
+ step=0.1,
116
+ value=0.8,
117
+ label="top_p",
118
+ render=False,
119
+ ),
120
+ gr.Slider(
121
+ minimum=1,
122
+ maximum=20,
123
+ step=1,
124
+ value=20,
125
+ label="top_k",
126
+ render=False,
127
+ ),
128
+ gr.Slider(
129
+ minimum=0.0,
130
+ maximum=2.0,
131
+ step=0.1,
132
+ value=1.0,
133
+ label="Repetition penalty",
134
+ render=False,
135
+ ),
136
  ],
137
  examples=[
138
  ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],