qgyd2021 committed on
Commit
ed6ea08
·
1 Parent(s): 151f498

[update]add code

Browse files
Files changed (5) hide show
  1. .gitignore +6 -0
  2. README.md +6 -7
  3. main.py +127 -0
  4. project_settings.py +12 -0
  5. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ .git/
3
+ .idea/
4
+
5
+ **/flagged/
6
+ **/__pycache__/
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
- title: Gpt2 Chat
3
- emoji: 📊
4
- colorFrom: green
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 3.50.2
8
- app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: GPT2 Chat
3
+ emoji: 🐠
4
+ colorFrom: purple
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.41.2
8
+ app_file: main.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
main.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from collections import defaultdict
5
+ import os
6
+
7
+ import gradio as gr
8
+ from threading import Thread
9
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
10
+ from transformers.models.bert.tokenization_bert import BertTokenizer
11
+ from transformers.generation.streamers import TextIteratorStreamer
12
+ import torch
13
+
14
+ from project_settings import project_path
15
+
16
+
17
def get_args(argv=None):
    """Parse the command-line options of the demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what a
            plain ``get_args()`` call has always done.

    Returns:
        argparse.Namespace with the attributes ``max_new_tokens``, ``top_p``,
        ``temperature``, ``repetition_penalty`` and ``device``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--max_new_tokens", default=512, type=int,
        help="maximum number of tokens to generate",
    )
    parser.add_argument(
        "--top_p", default=0.9, type=float,
        help="nucleus sampling probability mass",
    )
    parser.add_argument(
        "--temperature", default=0.35, type=float,
        help="sampling temperature",
    )
    parser.add_argument(
        "--repetition_penalty", default=1.0, type=float,
        help="repetition penalty passed to generation",
    )
    parser.add_argument(
        "--device",
        # Resolved once at parse time: prefer CUDA when torch can see a GPU.
        default="cuda" if torch.cuda.is_available() else "cpu",
        type=str,
        help="torch device to run on, e.g. cpu, cuda or auto",
    )

    return parser.parse_args(argv)
28
+
29
+
30
# Markdown text handed to the Gradio interface as its description.
description = "\n## GPT2 Chat\n"


# Module-level list of example inputs; it is empty here.
examples = list()
38
+
39
+
40
def main():
    """Build the Gradio text-generation demo and launch it.

    Reads the command-line options via ``get_args()``, resolves the torch
    device, defines the streaming generation handler and starts a queued
    Gradio interface. Blocks inside ``launch()`` and returns ``None``.
    """
    args = get_args()

    # "auto" picks CUDA when available; any other value is used verbatim.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    default_model_name = "qgyd2021/lib_service_4chan"

    input_text_box = gr.Text(label="text")
    output_text_box = gr.Text(lines=4, label="generated_content")

    # (tokenizer, model) pairs keyed by model name, so the weights are loaded
    # once instead of on every request.
    loaded_models = {}

    def load_model(model_name: str):
        """Return the cached (tokenizer, model) pair, loading it on first use."""
        if model_name not in loaded_models:
            tokenizer = BertTokenizer.from_pretrained(model_name)
            model = GPT2LMHeadModel.from_pretrained(model_name)
            # The inputs are moved to `device` below, so the model must live
            # there too; otherwise generate() fails on CUDA inside the worker
            # thread and the streamer loop would never terminate.
            model = model.to(device).eval()
            loaded_models[model_name] = (tokenizer, model)
        return loaded_models[model_name]

    def fn_stream(text: str,
                  max_new_tokens: int = 200,
                  top_p: float = 0.85,
                  temperature: float = 0.35,
                  repetition_penalty: float = 1.2,
                  model_name: str = "qgyd2021/lib_service_4chan",
                  is_chat: bool = True,
                  ):
        """Generate a continuation of ``text`` and yield it incrementally.

        Args:
            text: Prompt text.
            max_new_tokens: Upper bound on the number of generated tokens.
            top_p: Nucleus sampling threshold.
            temperature: Sampling temperature.
            repetition_penalty: Repetition penalty forwarded to ``generate``.
            model_name: Hub id of the model to load; an empty value falls
                back to the default model.
            is_chat: When true, a [SEP] token is appended to the prompt and
                also used as the end-of-sequence token.

        Yields:
            The accumulated output text after each streamed chunk.
        """
        if not model_name:
            model_name = default_model_name
        tokenizer, model = load_model(model_name)

        text_encoded = tokenizer(text, add_special_tokens=False)

        # Prompt layout: [CLS] tokens... and, in chat mode, a trailing [SEP].
        input_ids = [tokenizer.cls_token_id, *text_encoded["input_ids"]]
        if is_chat:
            input_ids.append(tokenizer.sep_token_id)
        input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)

        streamer = TextIteratorStreamer(tokenizer=tokenizer)

        generation_kwargs = dict(
            inputs=input_ids,
            # Slider values may arrive as int or float; normalise them so the
            # generation code receives the numeric types it expects.
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            eos_token_id=tokenizer.sep_token_id if is_chat else None,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        # generate() runs in a worker thread while this generator consumes
        # the streamer and forwards partial results to the UI.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output = ""
        for piece in streamer:
            # Drop the spaces the tokenizer puts between tokens and map the
            # special tokens to their display form.
            piece = piece.replace(" ", "")
            piece = piece.replace("[CLS]", "")
            piece = piece.replace("[SEP]", "\n")
            piece = piece.replace("[UNK]", "")

            output += piece
            # Only yield the text; the shared output component is not
            # mutated, so no state accumulates across requests.
            yield output

        thread.join()

    demo = gr.Interface(
        fn=fn_stream,
        inputs=[
            input_text_box,
            gr.Slider(minimum=1, maximum=512, value=512, step=1, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
            # Sampling rejects a temperature or repetition penalty of 0, so
            # both sliders start just above it.
            gr.Slider(minimum=0.01, maximum=1, value=0.35, step=0.01, label="temperature"),
            gr.Slider(minimum=0.01, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
            # A default selection keeps model_name from arriving as None.
            gr.Dropdown(choices=[default_model_name], value=default_model_name, label="model_name"),
            # Default matches fn_stream's is_chat default and the example row.
            gr.Checkbox(value=True, label="is_chat"),
        ],
        outputs=[output_text_box],
        examples=[
            ["怎样擦屁股才能擦的干净", 512, 0.75, 0.35, 1.2, "qgyd2021/lib_service_4chan", True],
        ],
        cache_examples=False,
        examples_per_page=50,
        title="H Novel Generate",
        description=description,
    )
    demo.queue().launch()
124
+
125
+
126
# Script entry point: start the demo only when this file is run directly.
if "__main__" == __name__:
    main()
project_settings.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from pathlib import Path
5
+
6
+
7
# Absolute path of the directory that contains this settings module.
project_path = Path(os.path.abspath(os.path.dirname(__file__)))


# Nothing to do when the module is executed directly.
if "__main__" == __name__:
    pass
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio==3.41.2
2
+ pydantic==1.10.12
3
+ thinc==7.4.6
4
+ spacy==2.3.9
5
+ transformers==4.30.2
6
+ numpy==1.21.4
7
+ tqdm==4.62.3
8
+ torch==1.13.0
9
+ datasets