""" |
Sample from a trained model |
""" |
import os |
import pickle |
from contextlib import nullcontext |
import torch |
import tiktoken |
from model import GPTConfig, GPT |
from tqdm import tqdm |
import random |
import numpy as np |
from transformers import AutoTokenizer |
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode |
import argparse |
import itertools |
import random |
# ---------------------------------------------------------------------------
# Command-line interface.
# Every option is parsed once and then unpacked into a module-level name so
# the rest of the script can refer to plain variables.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--init_from",
    type=str,
    default="resume",
    # The previous help text described a data directory, which this option is
    # not: it selects how the model is constructed further down the script.
    help="'resume' to load the checkpoint given by --ckpt_path, or a name starting with 'gpt2' to pass to GPT.from_pretrained",
)
parser.add_argument("--out_path", type=str, required=True)
parser.add_argument("--num_samples", type=int, default=100000)
parser.add_argument(
    "--max_new_tokens",
    type=int,
    required=True,
    help="number of tokens generated in each sample",
)
parser.add_argument(
    "--strategy",
    type=str,
    default='top_k',
    help="should be in ['greedy_search', 'sampling', 'top_k', 'beam_search']",
)
parser.add_argument(
    "--temperature",
    type=float,
    default=1.0,
    help="1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions",
)
parser.add_argument(
    "--top_k",
    type=int,
    default=20,
    help="retain only the top_k most likely tokens, clamp others to have 0 probability",
)
parser.add_argument(
    "--ckpt_path",
    type=str,
    required=True,
    help="path to a checkpoint/model",
)
parser.add_argument(
    "--tokenizer_path",
    type=str,
    required=True,
    help="path to a tokenizer directory",
)
parser.add_argument("--start", type=str, default="<|endoftext|>")
parser.add_argument("--repetition_penalty", type=float, default=1.0)
parser.add_argument(
    "--shuffle_token",
    action='store_true',
    help="Enable shuffling of tokens before decoding",
)
args = parser.parse_args()

# Unpack the parsed options into the names used by the rest of the script.
init_from = args.init_from
out_path = args.out_path
num_samples = args.num_samples
max_new_tokens = args.max_new_tokens
strategy = args.strategy
temperature = args.temperature
top_k = args.top_k
ckpt_path = args.ckpt_path
tokenizer_path = args.tokenizer_path
start = args.start
repetition_penalty = args.repetition_penalty
# ---------------------------------------------------------------------------
# Runtime setup: RNG seed, device/dtype selection, tokenizer and the autocast
# context used while generating.
# ---------------------------------------------------------------------------

# A fresh seed is drawn on every run, so two invocations are not expected to
# produce the same samples.
seed = random.randint(1, 6666)

# Hard-coded execution settings.
device = 'cuda'
dtype = 'float32'
compile = False  # NOTE(review): shadows the builtin `compile`; read later in the script

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Seed both the CPU and the CUDA generators of torch.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Allow TF32 in matmul and in cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Derive the device family and the torch dtype from the settings above.
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'

ptdtype = {
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
}[dtype]

# On CPU no autocast is applied; otherwise generation runs under autocast
# with the dtype chosen above.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# ---------------------------------------------------------------------------
# Model construction.
# `init_from` selects the source of the weights:
#   * 'resume'  -> rebuild the model from the checkpoint at `ckpt_path`
#   * 'gpt2...' -> delegate to GPT.from_pretrained with dropout disabled
# Anything else is rejected explicitly; previously neither branch ran and the
# script failed later with a NameError on `model`.
# ---------------------------------------------------------------------------
if init_from == 'resume':
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' prefix from parameter names so they match the
    # names expected by the freshly built model. Iterate over a snapshot of
    # the keys because the dict is mutated inside the loop.
    unwanted_prefix = '_orig_mod.'
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
else:
    raise ValueError(
        f"unsupported init_from {init_from!r}: "
        "expected 'resume' or a name starting with 'gpt2'"
    )

# Switch to inference mode and move the weights to the target device.
model.eval()
model.to(device)
if compile:
    model = torch.compile(model)
# ---------------------------------------------------------------------------
# Sampling loop.
# Encodes the start prompt once, then generates `num_samples` sequences and
# appends each decoded sequence as one line to `out_path`.
# ---------------------------------------------------------------------------
load_meta = False  # not read anywhere in the visible part of the script

encode = tokenizer.encode
decode = tokenizer.decode

# `start` is already a string, so it is encoded directly (the previous
# "".join(start) rebuilt the same string character by character).
start_ids = encode(start)
# Add a leading dimension of size 1 in front of the prompt ids.
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

# Append mode keeps samples from earlier runs. The encoding is pinned to
# UTF-8 so that decoded non-ASCII text does not depend on the platform locale.
with open(out_path, 'a', encoding='utf-8') as f, torch.no_grad(), ctx:
    for _ in tqdm(range(num_samples), desc="Generating samples"):
        generated = model.generate(
            x,
            max_new_tokens,
            strategy=strategy,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )
        # Only the first row of the result is kept, as a plain list of ids.
        token_sequence = generated[0].tolist()
        if args.shuffle_token:
            # Shuffles the token ids in place before they are decoded.
            random.shuffle(token_sequence)
        f.write(decode(token_sequence) + '\n')
        # Flush after every sample; presumably so that an interrupted run
        # still leaves the finished samples on disk.
        f.flush()