qgyd2021 commited on
Commit
9f1d0ce
·
1 Parent(s): 9deaf72

[update]add model

Browse files
examples/exercises/chinese_porn_novel/1.prepare_data.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ import random
7
+ import sys
8
+
9
+ pwd = os.path.abspath(os.path.dirname(__file__))
10
+ sys.path.append(os.path.join(pwd, '../../../'))
11
+
12
+ from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
13
+ from tqdm import tqdm
14
+
15
+ from project_settings import project_path
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+
21
+ parser.add_argument("--dataset_path", default="qgyd2021/h_novel", type=str)
22
+ # parser.add_argument("--dataset_name", default="ltxsba_500m", type=str)
23
+ parser.add_argument("--dataset_name", default="ltxsba_5gb", type=str)
24
+ parser.add_argument("--dataset_split", default="train", type=str)
25
+ parser.add_argument(
26
+ "--dataset_cache_dir",
27
+ default=(project_path / "hub_datasets").as_posix(),
28
+ type=str
29
+ )
30
+ parser.add_argument("--train_subset", default="train.jsonl", type=str)
31
+ parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
32
+ args = parser.parse_args()
33
+ return args
34
+
35
+
36
+ def main():
37
+ args = get_args()
38
+
39
+ dataset_dict = load_dataset(
40
+ path=args.dataset_path,
41
+ name=args.dataset_name,
42
+ # split=args.dataset_split,
43
+ cache_dir=args.dataset_cache_dir,
44
+ streaming=True,
45
+ )
46
+
47
+ train_dataset = dataset_dict["train"]
48
+
49
+ with open(args.train_subset, "w", encoding="utf-8") as ftrain, \
50
+ open(args.valid_subset, "w", encoding="utf-8") as fvalid:
51
+ for sample in tqdm(train_dataset):
52
+ # print(sample)
53
+
54
+ source = sample["source"]
55
+ idx = sample["idx"]
56
+ filename = sample["filename"]
57
+ novel_name = sample["novel_name"]
58
+ row_idx = sample["row_idx"]
59
+ text = sample["text"]
60
+
61
+ row = {
62
+ "text": text
63
+ }
64
+ row = json.dumps(row, ensure_ascii=False)
65
+
66
+ if random.random() < 0.95:
67
+ ftrain.write("{}\n".format(row))
68
+ else:
69
+ fvalid.write("{}\n".format(row))
70
+
71
+ return
72
+
73
+
74
+ if __name__ == '__main__':
75
+ main()
examples/exercises/chinese_porn_novel/2.train_model.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ 参考链接:
5
+ https://www.thepythoncode.com/article/pretraining-bert-huggingface-transformers-in-python
6
+ https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py
7
+
8
+ """
9
+ import argparse
10
+ from itertools import chain
11
+ import os
12
+ from pathlib import Path
13
+ import platform
14
+
15
+ from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
16
+ import torch
17
+ from transformers.data.data_collator import DataCollatorForLanguageModeling
18
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
19
+ from transformers.models.bert.tokenization_bert import BertTokenizer
20
+ from transformers.trainer import Trainer
21
+ from transformers.training_args import TrainingArguments
22
+
23
+ from project_settings import project_path
24
+
25
+
26
+ def get_args():
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument(
29
+ "--pretrained_model_name_or_path",
30
+ default=(project_path / "pretrained_models/gpt2-chinese-cluecorpussmall").as_posix(),
31
+ type=str
32
+ )
33
+
34
+ parser.add_argument("--train_subset", default="train.jsonl", type=str)
35
+ parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
36
+
37
+ parser.add_argument("--output_dir", default="serialization_dir", type=str)
38
+ parser.add_argument("--overwrite_output_dir", action="store_true")
39
+ parser.add_argument("--evaluation_strategy", default="no", choices=["no", "steps", "epoch"], type=str)
40
+ parser.add_argument("--per_device_train_batch_size", default=8, type=int)
41
+ parser.add_argument("--gradient_accumulation_steps", default=4, type=int)
42
+ parser.add_argument("--learning_rate", default=1e-5, type=float)
43
+ parser.add_argument("--weight_decay", default=0, type=float)
44
+ parser.add_argument("--max_grad_norm", default=1.0, type=float)
45
+ parser.add_argument("--num_train_epochs", default=3.0, type=float)
46
+ parser.add_argument("--max_steps", default=-1, type=int)
47
+ parser.add_argument("--lr_scheduler_type", default="cosine", type=str)
48
+ parser.add_argument("--warmup_ratio", default=0.0, type=float)
49
+ parser.add_argument("--warmup_steps", default=3000, type=int)
50
+ parser.add_argument("--logging_steps", default=300, type=int)
51
+ parser.add_argument("--save_strategy", default="steps", type=str)
52
+ parser.add_argument("--save_steps", default=500, type=int)
53
+ parser.add_argument("--save_total_limit", default=3, type=int)
54
+ parser.add_argument("--no_cuda", action="store_true")
55
+ parser.add_argument("--seed", default=3407, type=str, help="https://arxiv.org/abs/2109.08203")
56
+ # parser.add_argument("--fp16", action="store_true")
57
+ parser.add_argument("--fp16", action="store_false")
58
+ parser.add_argument("--half_precision_backend", default="auto", type=str)
59
+ parser.add_argument("--dataloader_num_workers", default=5, type=int)
60
+ parser.add_argument("--disable_tqdm", action="store_false")
61
+ parser.add_argument("--remove_unused_columns", action="store_false")
62
+ # parser.add_argument("--deepspeed", default="ds_z3_config.json", type=str)
63
+ parser.add_argument("--deepspeed", default=None, type=str)
64
+ parser.add_argument("--optim", default="adamw_hf", type=str)
65
+ parser.add_argument("--report_to", default="tensorboard", type=str)
66
+ parser.add_argument("--resume_from_checkpoint", default=None, type=str)
67
+ # parser.add_argument("--gradient_checkpointing", action="store_true")
68
+ parser.add_argument("--gradient_checkpointing", action="store_false")
69
+
70
+ parser.add_argument("--truncate_longer_samples", action="store_true")
71
+ # parser.add_argument("--truncate_longer_samples", action="store_false")
72
+ parser.add_argument("--max_seq_length", default=1024, type=int)
73
+
74
+ args = parser.parse_args()
75
+ return args
76
+
77
+
78
+ def main():
79
+ args = get_args()
80
+
81
+ # dataset
82
+ dataset_dict = DatasetDict()
83
+ train_data_files = [args.train_subset]
84
+ dataset_dict["train"] = load_dataset(
85
+ path="json", data_files=[str(file) for file in train_data_files]
86
+ )["train"]
87
+ valid_data_files = [args.valid_subset]
88
+ dataset_dict["valid"] = load_dataset(
89
+ path="json", data_files=[str(file) for file in valid_data_files]
90
+ )["train"]
91
+
92
+ print(dataset_dict)
93
+
94
+ # model
95
+ tokenizer = BertTokenizer.from_pretrained(args.pretrained_model_name_or_path)
96
+ model = GPT2LMHeadModel.from_pretrained(args.pretrained_model_name_or_path)
97
+
98
+ def encode_with_truncation(examples):
99
+ outputs = tokenizer.__call__(examples['text'],
100
+ truncation=True,
101
+ padding='max_length',
102
+ max_length=args.max_seq_length,
103
+ return_special_tokens_mask=True)
104
+ return outputs
105
+
106
+ def encode_without_truncation(examples):
107
+ outputs = tokenizer.__call__(examples['text'],
108
+ return_special_tokens_mask=True)
109
+ return outputs
110
+
111
+ def group_texts(examples):
112
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
113
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
114
+ if total_length >= args.max_seq_length:
115
+ total_length = (total_length // args.max_seq_length) * args.max_seq_length
116
+
117
+ result = {
118
+ k: [t[i: i + args.max_seq_length] for i in range(0, total_length, args.max_seq_length)]
119
+ for k, t in concatenated_examples.items()
120
+ }
121
+ return result
122
+
123
+ if args.truncate_longer_samples:
124
+ dataset_dict = dataset_dict.map(
125
+ encode_with_truncation,
126
+ batched=True,
127
+ drop_last_batch=True,
128
+ keep_in_memory=False,
129
+ # num_proc=None if platform.system() == 'Windows' else os.cpu_count() // 2,
130
+ num_proc=None,
131
+ )
132
+ dataset_dict.set_format(type="torch", columns=["input_ids", "attention_mask"])
133
+ else:
134
+ dataset_dict = dataset_dict.map(
135
+ encode_without_truncation,
136
+ batched=True,
137
+ drop_last_batch=True,
138
+ keep_in_memory=False,
139
+ # num_proc=None if platform.system() == 'Windows' else os.cpu_count() // 2,
140
+ num_proc=None,
141
+ )
142
+ dataset_dict.set_format(type="torch", columns=["input_ids", "attention_mask"])
143
+
144
+ dataset_dict = dataset_dict.map(
145
+ group_texts,
146
+ batched=True,
147
+ drop_last_batch=True,
148
+ keep_in_memory=False,
149
+ # num_proc=None if platform.system() == 'Windows' else os.cpu_count() // 2,
150
+ num_proc=None,
151
+ )
152
+ dataset_dict.set_format("torch")
153
+
154
+ data_collator = DataCollatorForLanguageModeling(
155
+ tokenizer=tokenizer, mlm=False
156
+ )
157
+
158
+ training_args = TrainingArguments(
159
+ output_dir=args.output_dir,
160
+ overwrite_output_dir=args.overwrite_output_dir,
161
+ evaluation_strategy=args.evaluation_strategy,
162
+ per_device_train_batch_size=args.per_device_train_batch_size,
163
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
164
+ learning_rate=args.learning_rate,
165
+ num_train_epochs=args.num_train_epochs,
166
+ max_steps=args.max_steps,
167
+ lr_scheduler_type=args.lr_scheduler_type,
168
+ warmup_steps=args.warmup_steps,
169
+ logging_steps=args.logging_steps,
170
+ save_steps=args.save_steps,
171
+ save_total_limit=args.save_total_limit,
172
+ no_cuda=args.no_cuda,
173
+ fp16=args.fp16,
174
+ half_precision_backend=args.half_precision_backend,
175
+ # deepspeed=args.deepspeed,
176
+ report_to=args.report_to,
177
+ resume_from_checkpoint=args.resume_from_checkpoint,
178
+ gradient_checkpointing=args.gradient_checkpointing,
179
+ )
180
+
181
+ trainer = Trainer(
182
+ model=model,
183
+ args=training_args,
184
+ data_collator=data_collator,
185
+ train_dataset=dataset_dict["train"],
186
+ )
187
+ train_result = trainer.train()
188
+
189
+ # 保存最好的 checkpoint
190
+ final_save_path = os.path.join(training_args.output_dir, "final")
191
+ trainer.save_model(final_save_path) # Saves the tokenizer too
192
+ # 保存训练指标
193
+ metrics = train_result.metrics
194
+ trainer.log_metrics("train", metrics)
195
+ trainer.save_metrics("train", metrics)
196
+ trainer.save_state()
197
+
198
+ tokenizer.save_pretrained(final_save_path)
199
+ return
200
+
201
+
202
+ if __name__ == '__main__':
203
+ main()
examples/exercises/chinese_porn_novel/3.test_model.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ import sys
6
+
7
+ pwd = os.path.abspath(os.path.dirname(__file__))
8
+ sys.path.append(os.path.join(pwd, '../../../'))
9
+
10
+ import torch
11
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
12
+ from transformers.models.bert.tokenization_bert import BertTokenizer
13
+
14
+ from project_settings import project_path
15
+
16
+
17
+ def get_args():
18
+ """
19
+ python3 3.test_model.py \
20
+ --repetition_penalty 1.2 \
21
+ --trained_model_path /data/tianxing/PycharmProjects/Transformers/trained_models/gpt2_chinese_h_novel
22
+
23
+ python3 3.test_model.py \
24
+ --trained_model_path /data/tianxing/PycharmProjects/Transformers/trained_models/gpt2_chinese_h_novel
25
+
26
+ """
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument(
29
+ '--trained_model_path',
30
+ default=(project_path / "pretrained_models/gpt2-chinese-cluecorpussmall").as_posix(),
31
+ type=str,
32
+ )
33
+ parser.add_argument('--device', default='auto', type=str)
34
+
35
+ parser.add_argument('--max_new_tokens', default=512, type=int)
36
+ parser.add_argument('--top_p', default=0.85, type=float)
37
+ parser.add_argument('--temperature', default=0.35, type=float)
38
+ parser.add_argument('--repetition_penalty', default=1.2, type=float)
39
+
40
+ args = parser.parse_args()
41
+ return args
42
+
43
+
44
+ def main():
45
+ args = get_args()
46
+
47
+ if args.device == 'auto':
48
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
49
+ else:
50
+ device = args.device
51
+
52
+ # pretrained model
53
+ tokenizer = BertTokenizer.from_pretrained(args.trained_model_path)
54
+ model = GPT2LMHeadModel.from_pretrained(args.trained_model_path)
55
+
56
+ model.eval()
57
+ model = model.to(device)
58
+
59
+ while True:
60
+ text = input('prefix: ')
61
+
62
+ if text == "Quit":
63
+ break
64
+ text = '{}'.format(text)
65
+ input_ids = tokenizer(text, return_tensors="pt").input_ids
66
+ input_ids = input_ids[:, :-1]
67
+ # print(input_ids)
68
+ # print(type(input_ids))
69
+ input_ids = input_ids.to(device)
70
+
71
+ outputs = model.generate(input_ids,
72
+ max_new_tokens=512,
73
+ do_sample=True,
74
+ top_p=args.top_p,
75
+ temperature=args.temperature,
76
+ repetition_penalty=args.repetition_penalty,
77
+ eos_token_id=tokenizer.sep_token_id,
78
+ pad_token_id=tokenizer.pad_token_id
79
+ )
80
+ rets = tokenizer.batch_decode(outputs)
81
+ output = rets[0].replace(" ", "").replace("[CLS]", "").replace("[SEP]", "")
82
+ print("{}".format(output))
83
+
84
+ return
85
+
86
+
87
+ if __name__ == '__main__':
88
+ main()
examples/exercises/chinese_porn_novel/README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ## 预训练 GPT 模型
2
+
3
+ ```text
4
+ 参考链接:
5
+ https://huggingface.co/docs/transformers/model_doc/openai-gpt
6
+ https://huggingface.co/learn/nlp-course/chapter7/6
7
+
8
+ ```
examples/exercises/chinese_porn_novel/run.sh ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # nohup sh run.sh --stage 0 --stop_stage 1 --system_version centos &
4
+ # sh run.sh --stage 0 --stop_stage 1 --system_version windows
5
+ # sh run.sh --stage 0 --stop_stage 0 --system_version centos
6
+ # sh run.sh --stage 2 --stop_stage 2 --system_version centos --checkpoint_name final
7
+ # sh run.sh --stage -1 --stop_stage 1
8
+
9
+ # bitsandbytes
10
+ export LD_LIBRARY_PATH="/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
11
+
12
+ # params
13
+ system_version="windows";
14
+ verbose=true;
15
+ stage=0 # start from 0 if you need to start from data preparation
16
+ stop_stage=5
17
+
18
+ pretrained_model_name=gpt2-chinese-cluecorpussmall
19
+
20
+ train_subset=train.jsonl
21
+ valid_subset=valid.jsonl
22
+
23
+ final_model_name=gpt2_chinese_h_novel
24
+
25
+ checkpoint_name=final
26
+
27
+ # parse options
28
+ while true; do
29
+ [ -z "${1:-}" ] && break; # break if there are no arguments
30
+ case "$1" in
31
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
32
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
33
+ old_value="(eval echo \\$$name)";
34
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
35
+ was_bool=true;
36
+ else
37
+ was_bool=false;
38
+ fi
39
+
40
+ # Set the variable to the right value-- the escaped quotes make it work if
41
+ # the option had spaces, like --cmd "queue.pl -sync y"
42
+ eval "${name}=\"$2\"";
43
+
44
+ # Check that Boolean-valued arguments are really Boolean.
45
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
46
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
47
+ exit 1;
48
+ fi
49
+ shift 2;
50
+ ;;
51
+
52
+ *) break;
53
+ esac
54
+ done
55
+
56
+ $verbose && echo "system_version: ${system_version}"
57
+
58
+ work_dir="$(pwd)"
59
+ file_dir="$(pwd)/file_dir"
60
+ pretrained_models_dir="${work_dir}/../../../pretrained_models";
61
+ serialization_dir="${file_dir}/serialization_dir"
62
+ final_model_dir="${work_dir}/../../../trained_models/${final_model_name}";
63
+
64
+ mkdir -p "${file_dir}"
65
+ mkdir -p "${pretrained_models_dir}"
66
+ mkdir -p "${serialization_dir}"
67
+ mkdir -p "${final_model_dir}"
68
+
69
+
70
+ export PYTHONPATH="${work_dir}/../../.."
71
+
72
+
73
+ if [ $system_version == "windows" ]; then
74
+ alias python3='C:/Users/tianx/PycharmProjects/virtualenv/Transformers/Scripts/python.exe'
75
+ elif [ $system_version == "centos" ]; then
76
+ # conda activate Transformers
77
+ alias python3='/usr/local/miniconda3/envs/Transformers/bin/python3'
78
+ elif [ $system_version == "ubuntu" ]; then
79
+ # conda activate Transformers
80
+ alias python3='/usr/local/miniconda3/envs/Transformers/bin/python3'
81
+ fi
82
+
83
+
84
+ declare -A pretrained_model_dict
85
+ pretrained_model_dict=(
86
+ ["gpt2-chinese-cluecorpussmall"]="https://huggingface.co/uer/gpt2-chinese-cluecorpussmall"
87
+ ["gpt2"]="https://huggingface.co/gpt2"
88
+ ["japanese-gpt2-medium"]="https://huggingface.co/rinna/japanese-gpt2-medium"
89
+
90
+ )
91
+ pretrained_model_dir="${pretrained_models_dir}/${pretrained_model_name}"
92
+
93
+
94
+ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
95
+ $verbose && echo "stage -1: download pretrained model"
96
+ cd "${file_dir}" || exit 1;
97
+
98
+ if [ ! -d "${pretrained_model_dir}" ]; then
99
+ cd "${pretrained_models_dir}" || exit 1;
100
+
101
+ repository_url="${pretrained_model_dict[${pretrained_model_name}]}"
102
+ git clone "${repository_url}"
103
+
104
+ cd "${pretrained_model_dir}" || exit 1;
105
+ rm flax_model.msgpack && rm pytorch_model.bin && rm tf_model.h5
106
+ wget "${repository_url}/resolve/main/pytorch_model.bin"
107
+ fi
108
+ fi
109
+
110
+
111
+ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
112
+ $verbose && echo "stage 0: prepare data"
113
+ cd "${work_dir}" || exit 1;
114
+
115
+ python3 1.prepare_data.py \
116
+ --train_subset "${file_dir}/${train_subset}" \
117
+ --valid_subset "${file_dir}/${valid_subset}" \
118
+
119
+ fi
120
+
121
+
122
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
123
+ $verbose && echo "stage 1: train model"
124
+ cd "${work_dir}" || exit 1;
125
+
126
+ python3 2.train_model.py \
127
+ --train_subset "${file_dir}/${train_subset}" \
128
+ --valid_subset "${file_dir}/${valid_subset}" \
129
+ --pretrained_model_name_or_path "${pretrained_models_dir}/${pretrained_model_name}" \
130
+ --output_dir "${serialization_dir}"
131
+
132
+ fi
133
+
134
+
135
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
136
+ $verbose && echo "stage 2: collect files"
137
+ cd "${work_dir}" || exit 1;
138
+
139
+ cp "${serialization_dir}/${checkpoint_name}/pytorch_model.bin" "${final_model_dir}/pytorch_model.bin"
140
+
141
+ cp "${pretrained_models_dir}/${pretrained_model_name}/config.json" "${final_model_dir}/config.json"
142
+ cp "${pretrained_models_dir}/${pretrained_model_name}/special_tokens_map.json" "${final_model_dir}/special_tokens_map.json"
143
+ cp "${pretrained_models_dir}/${pretrained_model_name}/tokenizer_config.json" "${final_model_dir}/tokenizer_config.json"
144
+ cp "${pretrained_models_dir}/${pretrained_model_name}/vocab.txt" "${final_model_dir}/vocab.txt"
145
+
146
+ fi
examples/exercises/chinese_porn_novel/stop.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ kill -9 `ps -aef | grep 'run.sh' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
4
+
5
+ kill -9 `ps -aef | grep 'Transformers/bin/python3' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
main.py CHANGED
@@ -117,7 +117,9 @@ def main():
117
  yield output
118
 
119
  model_name_choices = ["trained_models/lib_service_4chan"] \
120
- if platform.system() == "Windows" else ["qgyd2021/lib_service_4chan", "qgyd2021/chinese_chitchat"]
 
 
121
  demo = gr.Interface(
122
  fn=fn_stream,
123
  inputs=[
@@ -133,6 +135,8 @@ def main():
133
  examples=[
134
  ["怎样擦屁股才能擦的干净", 512, 0.75, 0.35, 1.2, "qgyd2021/lib_service_4chan", True],
135
  ["你好", 512, 0.75, 0.35, 1.2, "qgyd2021/chinese_chitchat", True],
 
 
136
  ],
137
  cache_examples=False,
138
  examples_per_page=50,
 
117
  yield output
118
 
119
  model_name_choices = ["trained_models/lib_service_4chan"] \
120
+ if platform.system() == "Windows" else \
121
+ ["qgyd2021/lib_service_4chan", "qgyd2021/chinese_chitchat", "qgyd2021/chinese_porn_novel"]
122
+
123
  demo = gr.Interface(
124
  fn=fn_stream,
125
  inputs=[
 
135
  examples=[
136
  ["怎样擦屁股才能擦的干净", 512, 0.75, 0.35, 1.2, "qgyd2021/lib_service_4chan", True],
137
  ["你好", 512, 0.75, 0.35, 1.2, "qgyd2021/chinese_chitchat", True],
138
+ ["白洁走到床边并脱去内衣, 一双硕大的", 512, 0.75, 0.35, 1.2, "qgyd2021/chinese_porn_novel", False],
139
+ ["男人走进房间, 上床, 压上", 512, 0.75, 0.35, 1.2, "qgyd2021/chinese_porn_novel", False],
140
  ],
141
  cache_examples=False,
142
  examples_per_page=50,