jslin09 committed
Commit 03746f8 · 1 Parent(s): 1578262

Create app.py

Files changed (1)
  1. app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
+ import torch
+ from peft import PeftModel, PeftConfig
+ import gradio as gr
+
+ from transformers import AutoTokenizer, BloomForCausalLM, GenerationConfig
+
+ BASE_MODEL = "bigscience/bloom-3b"
+ # Alternative: load the adapter from a local Colab Drive path.
+ # LORA_WEIGHTS = f"/content/drive/MyDrive/Colab Notebooks/LegalChatbot-{model_name}"
+ LORA_WEIGHTS = "jslin09/LegalChatbot-bloom-3b"
+
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+ config = PeftConfig.from_pretrained(LORA_WEIGHTS)
+
+ # Pick a device: CUDA if available, otherwise Apple MPS, otherwise CPU.
+ if torch.cuda.is_available():
+     device = "cuda"
+ else:
+     device = "cpu"
+
+ try:
+     if torch.backends.mps.is_available():
+         device = "mps"
+ except Exception:
+     pass
+
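+ # Load the base BLOOM model in a form suited to the detected device
+ # (8-bit on CUDA, fp16 on MPS, full precision on CPU), then attach the
+ # LoRA adapter weights on top of it.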
+ if device == "cuda":
+     model = BloomForCausalLM.from_pretrained(
+         BASE_MODEL,
+         load_in_8bit=True,
+         torch_dtype=torch.float16,
+         device_map="auto",
+     )
+     model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
+ elif device == "mps":
+     model = BloomForCausalLM.from_pretrained(
+         BASE_MODEL,
+         device_map={"": device},
+         torch_dtype=torch.float16,
+     )
+     model = PeftModel.from_pretrained(
+         model,
+         LORA_WEIGHTS,
+         device_map={"": device},
+         torch_dtype=torch.float16,
+     )
+ else:
+     model = BloomForCausalLM.from_pretrained(
+         BASE_MODEL,
+         device_map={"": device},
+         low_cpu_mem_usage=True,
+     )
+     model = PeftModel.from_pretrained(
+         model,
+         LORA_WEIGHTS,
+         device_map={"": device},
+     )
+
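+ # Alpaca-style prompt templates: an English version and a Traditional
+ # Chinese version ("指令" = Instruction, "輸入" = Input, "回應" = Response).
+ # The deployed app uses the Chinese template.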
+ def generate_prompt(instruction, input=None):
+     if input:
+         return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+ ### Instruction:
+ {instruction}
+
+ ### Input:
+ {input}
+
+ ### Response:"""
+     else:
+         return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+ ### Instruction:
+ {instruction}
+
+ ### Response:"""
+
+ def generate_prompt_tw(instruction, input=None):
+     if input:
+         return f"""以下是描述任務的指令,並與提供進一步上下文的輸入配對。編寫適當完成請求的回應。
+
+ ### 指令:
+ {instruction}
+
+ ### 輸入:
+ {input}
+
+ ### 回應:"""
+     else:
+         return f"""以下是描述任務的指令。編寫適當完成請求的回應。
+
+ ### 指令:
+ {instruction}
+
+ ### 回應:"""
+
+ model.eval()
+ if torch.__version__ >= "2":
+     model = torch.compile(model)
+
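+ # Build a prompt from the user's instruction/input, run beam-search
+ # generation, and return the text that follows the "### 回應:" marker.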
+ def evaluate(
+     instruction,
+     input=None,
+     temperature=0.1,
+     top_p=0.75,
+     top_k=40,
+     num_beams=4,
+     max_new_tokens=128,
+     **kwargs,
+ ):
+     # Chinese version; for English, call generate_prompt instead.
+     prompt = generate_prompt_tw(instruction, input)
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs["input_ids"].to(device)
+     generation_config = GenerationConfig(
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         num_beams=num_beams,
+         **kwargs,
+     )
+     with torch.no_grad():
+         generation_output = model.generate(
+             input_ids=input_ids,
+             generation_config=generation_config,
+             return_dict_in_generate=True,
+             output_scores=True,
+             max_new_tokens=max_new_tokens,
+         )
+     s = generation_output.sequences[0]
+     output = tokenizer.decode(s)
+     # English version: return output.split("### Response:")[1].strip()
+     return output.split("### 回應:")[1].strip()
+
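+ # Gradio UI: instruction and input text boxes plus sampling/decoding
+ # controls, with a single text box for the generated answer.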
+ gr.Interface(
+     fn=evaluate,
+     inputs=[
+         gr.components.Textbox(
+             lines=2, label="Instruction", placeholder="Tell me about alpacas."
+         ),
+         gr.components.Textbox(lines=2, label="Input", placeholder="none"),
+         gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
+         gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
+         gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
+         gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
+         gr.components.Slider(
+             minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
+         ),
+     ],
+     outputs=[
+         gr.components.Textbox(
+             lines=5,
+             label="Output",
+         )
+     ],
+     title="🌲 🌲 🌲 BLOOM-LoRA-LegalChatbot",
+     description="BLOOM-LoRA-LegalChatbot is a 3B-parameter BLOOM model fine-tuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and my Legal QA dataset, and makes use of the Hugging Face BLOOM implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",
+ ).launch()