Yurii Paniv committed on
Commit
6abca35
·
1 Parent(s): a94d4e4

Add paper config

Browse files
Files changed (1) hide show
  1. app.py +58 -30
app.py CHANGED
@@ -1,7 +1,13 @@
1
  import gradio as gr
2
 
3
  from peft import PeftModel, PeftConfig
4
- from transformers import MistralForCausalLM, TextIteratorStreamer, AutoTokenizer, BitsAndBytesConfig
 
 
 
 
 
 
5
  from time import sleep
6
  from threading import Thread
7
  from torch import float16
@@ -15,38 +21,60 @@ quant_config = BitsAndBytesConfig(
15
  bnb_4bit_use_double_quant=False,
16
  )
17
 
18
- model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1",
19
- quantization_config=quant_config)
20
- #device_map="auto",)
 
 
 
21
  model = PeftModel.from_pretrained(model, "lang-uk/dragoman").to("cuda")
22
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False, add_bos_token=False)
 
 
 
23
 
24
  @spaces.GPU(duration=30)
25
  def translate(input_text):
26
  generated_text = ""
27
  input_text = input_text.strip()
28
- for chunk in input_text.split("\n"):
29
- if not chunk:
30
- generated_text += "\n"
31
- yield generated_text
32
- continue
33
- chunk = f"[INST] {chunk} [/INST]"
34
- inputs = tokenizer([chunk], return_tensors="pt").to(model.device)
35
-
36
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
37
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200)
38
-
39
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
40
-
41
- thread.start()
42
-
43
- for new_text in streamer:
44
- generated_text += new_text
45
- yield generated_text
46
-
47
- generated_text += "\n"
48
- yield generated_text
49
-
50
-
51
- iface = gr.Interface(fn=translate, inputs="text", outputs="text", examples=[["who holds this neighborhood?"]])
52
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
  from peft import PeftModel, PeftConfig
4
+ from transformers import (
5
+ MistralForCausalLM,
6
+ TextIteratorStreamer,
7
+ AutoTokenizer,
8
+ BitsAndBytesConfig,
9
+ GenerationConfig,
10
+ )
11
  from time import sleep
12
  from threading import Thread
13
  from torch import float16
 
21
  bnb_4bit_use_double_quant=False,
22
  )
23
 
24
+ model = MistralForCausalLM.from_pretrained(
25
+ "mistralai/Mistral-7B-v0.1", quantization_config=quant_config
26
+ )
27
+
28
+
29
+ # device_map="auto",)
30
  model = PeftModel.from_pretrained(model, "lang-uk/dragoman").to("cuda")
31
+ tokenizer = AutoTokenizer.from_pretrained(
32
+ "mistralai/Mistral-7B-v0.1", use_fast=False, add_bos_token=False
33
+ )
34
+
35
 
36
  @spaces.GPU(duration=30)
37
  def translate(input_text):
38
  generated_text = ""
39
  input_text = input_text.strip()
40
+
41
+ input_text = f"[INST] {input_text} [/INST]"
42
+ inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
43
+
44
+ generation_kwargs = dict(inputs, max_new_tokens=200, num_beams=10, temperature=1) # streamer=streamer,
45
+
46
+
47
+ # streaming support
48
+ #streamer = TextIteratorStreamer(
49
+ # tokenizer, skip_prompt=True, skip_special_tokens=True
50
+ #)
51
+
52
+ #thread = Thread(target=model.generate, kwargs=generation_kwargs)
53
+
54
+ #thread.start()
55
+
56
+
57
+ #for new_text in streamer:
58
+ # generated_text += new_text
59
+ # yield generated_text
60
+
61
+ #generated_text += "\n"
62
+ #yield generated_text
63
+
64
+ output = model.generate(**generation_kwargs)
65
+ output = tokenizer.decode(output[0], skip_special_tokens=True).split("[/INST] ")[-1].strip()
66
+ return output
67
+
68
+
69
+ iface = gr.Interface(
70
+ fn=translate,
71
+ inputs=gr.Textbox(
72
+ value="",
73
+ label="Source sentence",
74
+ ),
75
+ outputs=gr.Textbox(label="Translated sentence"),
76
+ examples=[[
77
+ "ChatGPT (Chat Generative Pre-trained Transformer) is a chatbot developed by OpenAI and launched on November 30, 2022. Based on a large language model, it enables users to refine and steer a conversation towards a desired length, format, style, level of detail, and language. Successive prompts and replies, known as prompt engineering, are considered at each conversation stage as a context.[2] ",
78
+ "who holds this neighborhood?"]],
79
+ )
80
+ iface.launch()