MasaakiKotera
commited on
Upload sampling.py with huggingface_hub
Browse files- sampling.py +54 -16
sampling.py
CHANGED
@@ -23,6 +23,7 @@ parser.add_argument("--out_path", type=str, required=True)
|
|
23 |
parser.add_argument("--num_samples", type=int, required=False, default=100000)
|
24 |
parser.add_argument("--max_new_tokens", type=int, required=True, help="number of tokens generated in each sample")
|
25 |
parser.add_argument("--strategy",type=str, required=False,default='top_k',help="should be in ['greedy_search', 'sampling', 'top_k', 'beam_search']")
|
|
|
26 |
parser.add_argument("--temperature",type=float, required=False,default=1.0,help="1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions")
|
27 |
parser.add_argument("--top_k",type=int, required=False,default=20,help="retain only the top_k most likely tokens, clamp others to have 0 probability")
|
28 |
parser.add_argument("--ckpt_path",type=str, required=True,help="path to a checkpoint/model")
|
@@ -30,6 +31,7 @@ parser.add_argument("--tokenizer_path",type=str, required=True,help="path to a t
|
|
30 |
parser.add_argument("--start",type=str, required=False,default="<|endoftext|>")
|
31 |
parser.add_argument("--repetition_penalty",type=float, required=False,default=1.0)
|
32 |
parser.add_argument("--shuffle_token", action='store_true', help="Enable shuffling of tokens before decoding")
|
|
|
33 |
|
34 |
args = parser.parse_args()
|
35 |
init_from = args.init_from
|
@@ -37,17 +39,20 @@ out_path = args.out_path
|
|
37 |
num_samples = args.num_samples
|
38 |
max_new_tokens = args.max_new_tokens
|
39 |
strategy = args.strategy
|
|
|
|
|
40 |
temperature = args.temperature
|
41 |
top_k = args.top_k
|
42 |
ckpt_path = args.ckpt_path
|
43 |
tokenizer_path = args.tokenizer_path
|
44 |
start = args.start
|
45 |
repetition_penalty = args.repetition_penalty
|
|
|
|
|
46 |
|
47 |
# -----------------------------------------------------------------------------
|
48 |
seed = random.randint(1,6666)
|
49 |
-
|
50 |
-
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
|
51 |
dtype = 'float32'
|
52 |
# dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
|
53 |
compile = False # use PyTorch 2.0 to compile the model to be faster
|
@@ -91,20 +96,53 @@ load_meta = False
|
|
91 |
encode = tokenizer.encode
|
92 |
decode = tokenizer.decode
|
93 |
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
|
|
|
|
|
|
97 |
|
98 |
-
with open(out_path, 'a') as f:
|
99 |
-
with torch.no_grad():
|
100 |
-
with ctx:
|
101 |
-
for k in tqdm(range(num_samples), desc="Generating samples"):
|
102 |
-
token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty)[0].tolist()
|
103 |
-
|
104 |
-
# Shuffle tokens if --shuffle_token is specified
|
105 |
-
if args.shuffle_token:
|
106 |
-
random.shuffle(token_sequence)
|
107 |
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
23 |
parser.add_argument("--num_samples", type=int, required=False, default=100000)
|
24 |
parser.add_argument("--max_new_tokens", type=int, required=True, help="number of tokens generated in each sample")
|
25 |
parser.add_argument("--strategy",type=str, required=False,default='top_k',help="should be in ['greedy_search', 'sampling', 'top_k', 'beam_search']")
|
26 |
+
parser.add_argument("--beam_size",type=int, required=False,default=3,help="beam size for beam search")
|
27 |
parser.add_argument("--temperature",type=float, required=False,default=1.0,help="1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions")
|
28 |
parser.add_argument("--top_k",type=int, required=False,default=20,help="retain only the top_k most likely tokens, clamp others to have 0 probability")
|
29 |
parser.add_argument("--ckpt_path",type=str, required=True,help="path to a checkpoint/model")
|
|
|
31 |
parser.add_argument("--start",type=str, required=False,default="<|endoftext|>")
|
32 |
parser.add_argument("--repetition_penalty",type=float, required=False,default=1.0)
|
33 |
parser.add_argument("--shuffle_token", action='store_true', help="Enable shuffling of tokens before decoding")
|
34 |
+
parser.add_argument("--fasta", action='store_true', default=True, help="Enable writing output in FASTA format")
|
35 |
|
36 |
args = parser.parse_args()
|
37 |
init_from = args.init_from
|
|
|
39 |
num_samples = args.num_samples
|
40 |
max_new_tokens = args.max_new_tokens
|
41 |
strategy = args.strategy
|
42 |
+
assert strategy in ['greedy_search', 'sampling', 'top_k', 'beam_search']
|
43 |
+
beam_size = args.beam_size
|
44 |
temperature = args.temperature
|
45 |
top_k = args.top_k
|
46 |
ckpt_path = args.ckpt_path
|
47 |
tokenizer_path = args.tokenizer_path
|
48 |
start = args.start
|
49 |
repetition_penalty = args.repetition_penalty
|
50 |
+
fasta = args.fasta
|
51 |
+
|
52 |
|
53 |
# -----------------------------------------------------------------------------
|
54 |
seed = random.randint(1,6666)
|
55 |
+
device = 'cuda'
|
|
|
56 |
dtype = 'float32'
|
57 |
# dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
|
58 |
compile = False # use PyTorch 2.0 to compile the model to be faster
|
|
|
96 |
encode = tokenizer.encode
|
97 |
decode = tokenizer.decode
|
98 |
|
99 |
+
fasta_out_path = os.path.splitext(out_path)[0] + ".fasta" if fasta else None
|
100 |
+
|
101 |
+
if strategy in["sampling", "top_k"]:
|
102 |
+
start_ids = encode("".join(start))
|
103 |
+
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
|
104 |
+
|
105 |
+
|
106 |
+
with open(out_path, 'a') as f:
|
107 |
+
with open(fasta_out_path, 'a') if fasta else nullcontext() as fasta_f:
|
108 |
+
with torch.no_grad():
|
109 |
+
with ctx:
|
110 |
+
for k in tqdm(range(num_samples), desc="Generating samples"):
|
111 |
+
token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty)[0].tolist()
|
112 |
+
|
113 |
+
# Shuffle tokens if --shuffle_token is specified
|
114 |
+
if args.shuffle_token:
|
115 |
+
random.shuffle(token_sequence)
|
116 |
+
|
117 |
+
y = decode(token_sequence).replace(' ', '')
|
118 |
+
# y = decode(token_sequence).replace('\n', '').replace(' ', '') + '\n'
|
119 |
+
f.write(y)
|
120 |
+
f.flush()
|
121 |
+
|
122 |
+
|
123 |
+
if fasta:
|
124 |
+
fasta_entry = f">sample_{k}\n{y.replace(' ', '')}\n"
|
125 |
+
fasta_f.write(fasta_entry.strip() + '\n')
|
126 |
+
fasta_f.flush()
|
127 |
+
|
128 |
+
|
129 |
+
elif strategy in ["beam_search", "greedy_search"]:
|
130 |
+
with open(out_path, 'a') as f:
|
131 |
+
with open(fasta_out_path, 'a') if fasta else nullcontext() as fasta_f:
|
132 |
+
with torch.no_grad():
|
133 |
+
with ctx:
|
134 |
+
start = '<|endoftext|>'
|
135 |
+
start_ids = encode(start)
|
136 |
+
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
|
137 |
+
|
138 |
+
token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, beam_size=beam_size)[0].tolist()
|
139 |
|
140 |
+
y = decode(token_sequence).replace(' ', '')
|
141 |
+
f.write(y)
|
142 |
+
f.flush()
|
143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
+
if fasta:
|
146 |
+
fasta_entry = f">sample_{k}\n{y.replace(' ', '')}\n"
|
147 |
+
fasta_f.write(fasta_entry.strip() + '\n')
|
148 |
+
fasta_f.flush()
|