DuongTrongChi commited on
Commit
673210b
·
1 Parent(s): 9344443

setup-project

Browse files
Files changed (1) hide show
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from threading import Thread
3
+ from typing import Iterator
4
+
5
+ import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
+
11
+ MAX_MAX_NEW_TOKENS = 2048
12
+ DEFAULT_MAX_NEW_TOKENS = 1024
13
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
14
+
15
+
16
+ # DESCRIPTION = ""
17
+ # if not torch.cuda.is_available():
18
+ # DESCRIPTION = "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
19
+
20
+
21
+ if torch.cuda.is_available():
22
+ device = torch.device("cuda")
23
+ print('There are %d GPU(s) available.' % torch.cuda.device_count())
24
+ print('We will use the GPU:', torch.cuda.get_device_name(0))
25
+ else:
26
+ print('No GPU available, using the CPU instead.')
27
+ device = torch.device("cpu")
28
+
29
+
30
+
31
+ tokenizer = AutoTokenizer.from_pretrained("Back-up/T5-pretrain")
32
+ model = AutoModelForSeq2SeqLM.from_pretrained("Back-up/T5-large-QA")
33
+ model.to(device)
34
+
35
+
36
+ @spaces.GPU
37
+ def generate(
38
+ message: str,
39
+ chat_history: list[tuple[str, str]],
40
+ system_prompt: str,
41
+ max_new_tokens: int = 1024,
42
+ temperature: float = 0.6,
43
+ top_p: float = 0.9,
44
+ top_k: int = 50,
45
+ repetition_penalty: float = 1.2,
46
+ ) -> Iterator[str]:
47
+ tokenized_text = tokenizer.encode(message, return_tensors="pt").to(model.device)
48
+
49
+ model.eval()
50
+ summary_ids = model.generate(
51
+ tokenized_text,
52
+ max_length=1024,
53
+ min_length=8,
54
+ num_beams=5,
55
+ repetition_penalty=2.5,
56
+ length_penalty=1.0,
57
+ early_stopping=True
58
+ )
59
+ output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
60
+ yield output
61
+
62
+ chat_interface = gr.ChatInterface(
63
+ fn=generate,
64
+ additional_inputs=[
65
+ gr.Textbox(label="System prompt", lines=6),
66
+ gr.Slider(
67
+ label="Max new tokens",
68
+ minimum=1,
69
+ maximum=MAX_MAX_NEW_TOKENS,
70
+ step=1,
71
+ value=DEFAULT_MAX_NEW_TOKENS,
72
+ ),
73
+ gr.Slider(
74
+ label="Temperature",
75
+ minimum=0.1,
76
+ maximum=4.0,
77
+ step=0.1,
78
+ value=0.6,
79
+ ),
80
+ gr.Slider(
81
+ label="Top-p (nucleus sampling)",
82
+ minimum=0.05,
83
+ maximum=1.0,
84
+ step=0.05,
85
+ value=0.9,
86
+ ),
87
+ gr.Slider(
88
+ label="Top-k",
89
+ minimum=1,
90
+ maximum=1000,
91
+ step=1,
92
+ value=50,
93
+ ),
94
+ gr.Slider(
95
+ label="Repetition penalty",
96
+ minimum=1.0,
97
+ maximum=2.0,
98
+ step=0.05,
99
+ value=1.2,
100
+ ),
101
+ ],
102
+ stop_btn=None,
103
+ examples=[
104
+ ["Trường đại học Nông Lâm thành phố Hồ Chí Minh nằm ở đâu?"],
105
+ ["Mục tiêu chiến lược của trường đại học Nông Lâm thành phố Hồ Chí Minh là gì?"],
106
+ ["Sinh viên được khen thưởng cá nhân và tập thể khi nào?"],
107
+ ["Điều kiện cơ bản để được hỗ trợ vay tiền sinh viên là gì?"],
108
+ ["Trường Đại học Nông Lâm đã trải qua bao nhiêu năm hoạt động tính đến năm 2023?"],
109
+ ["Những hành vi nào của sinh viên bị coi là vi phạm quy định của Nhà trường?"],
110
+ ["Địa chỉ của Phân hiệu Trường Đại học Nông Lâm tại Ninh Thuận?"],
111
+ ["Làm thế nào khi sinh viên không hài lòng với việc giải quyết thắc mắc của Trưởng Bộ môn?"],
112
+ ["Làm thế để yêu cầu phúc khảo bài thi?"],
113
+ ["Nghĩa vụ của sinh viên là gì?"],
114
+ ["Viết cho tôi một chương trình tính số nguyên tố bằng python."]
115
+ ],
116
+ )
117
+
118
+ with gr.Blocks(css="style.css") as demo:
119
+ chat_interface.render()
120
+
121
+
122
+ if __name__ == "__main__":
123
+ demo.queue(max_size=20).launch()