# File size: 4,771 Bytes
# 508087f
"""
Sample from a trained model
"""
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from model import GPTConfig, GPT
from tqdm import tqdm
import random
import numpy as np
from transformers import AutoTokenizer
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
import argparse
import itertools
import random
# -----------------------------------------------------------------------------
# Command-line interface.
# Every option is unpacked into a module-level name below; the one exception is
# --shuffle_token, which the sampling loop reads straight from `args`.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--init_from", type=str, default="resume",
    help="'resume' to load the checkpoint at --ckpt_path, or a name starting "
         "with 'gpt2' to load pretrained weights via GPT.from_pretrained",
)
parser.add_argument(
    "--out_path", type=str, required=True,
    help="file that generated samples are appended to, one per line",
)
parser.add_argument(
    "--num_samples", type=int, default=100000,
    help="number of samples to generate",
)
parser.add_argument(
    "--max_new_tokens", type=int, required=True,
    help="number of tokens generated in each sample",
)
parser.add_argument(
    "--strategy", type=str, default='top_k',
    help="should be in ['greedy_search', 'sampling', 'top_k', 'beam_search']",
)
parser.add_argument(
    "--temperature", type=float, default=1.0,
    help="1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions",
)
parser.add_argument(
    "--top_k", type=int, default=20,
    help="retain only the top_k most likely tokens, clamp others to have 0 probability",
)
parser.add_argument(
    "--ckpt_path", type=str, required=True,
    help="path to a checkpoint/model",
)
parser.add_argument(
    "--tokenizer_path", type=str, required=True,
    help="path to a tokenizer directory",
)
parser.add_argument(
    "--start", type=str, default="<|endoftext|>",
    help="prompt text that is encoded and used as the starting context of every sample",
)
parser.add_argument(
    "--repetition_penalty", type=float, default=1.0,
    help="repetition penalty forwarded to model.generate",
)
parser.add_argument(
    "--shuffle_token", action='store_true',
    help="Enable shuffling of tokens before decoding",
)
args = parser.parse_args()

# Flatten the parsed options into the module-level names used by the rest of
# the script.
init_from = args.init_from
out_path = args.out_path
num_samples = args.num_samples
max_new_tokens = args.max_new_tokens
strategy = args.strategy
temperature = args.temperature
top_k = args.top_k
ckpt_path = args.ckpt_path
tokenizer_path = args.tokenizer_path
start = args.start
repetition_penalty = args.repetition_penalty
# -----------------------------------------------------------------------------
# Runtime configuration (edited in place; not exposed on the command line).
# A new seed is drawn on every run, so sampled output differs between runs.
# Swap in the fixed value below for repeatable sampling.
seed = random.randint(1, 6666)
# seed = 1337
# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so the
# checkpoint load (map_location=device) and model.to(device) do not fail on
# CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'float32'
# dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
# NOTE: this name shadows the builtin compile(); it is kept because the
# model-loading section of this script reads it.
compile = False  # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
# Load tokenizer & ckpt_path
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# -----------------------------------------------------------------------------
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
# Map the dtype name chosen above onto the corresponding torch dtype object.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Autocast is only entered on CUDA; on CPU the context manager is a no-op.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# model
# Build the model according to --init_from: either restore it from the
# checkpoint file at ckpt_path, or load pretrained GPT-2 weights by name.
if init_from == 'resume':
    # init from a model saved in a specific directory
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' prefix from parameter names so the keys match
    # the freshly constructed model (presumably left behind when the
    # checkpoint was saved from a torch.compile'd model; confirm against the
    # training script).
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):  # snapshot of the keys: the dict is mutated below
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
else:
    # Without this branch an unknown value would leave `model` undefined and
    # surface later as a confusing NameError.
    raise ValueError(
        f"unsupported --init_from value {init_from!r}: "
        "expected 'resume' or a name starting with 'gpt2'"
    )
model.eval()
model.to(device)
if compile:
    model = torch.compile(model)  # requires PyTorch 2.0 (optional)
# Sampling: encode the start prompt once, then generate num_samples
# continuations and append each decoded sample to out_path on its own line.
# NOTE: no meta pickle is loaded by the visible script and load_meta is never
# read; the assignment is kept only so the module-level name still exists.
load_meta = False
encode = tokenizer.encode
decode = tokenizer.decode
start_ids = encode(start)
# [None, ...] adds a leading dimension of size 1 (presumably the batch axis
# expected by model.generate; confirm against model.py).
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
# An explicit UTF-8 encoding keeps decoded text writable regardless of the
# platform's default locale encoding.
with open(out_path, 'a', encoding='utf-8') as f, torch.no_grad(), ctx:
    for _ in tqdm(range(num_samples), desc="Generating samples"):
        token_sequence = model.generate(
            x,
            max_new_tokens,
            strategy=strategy,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )[0].tolist()
        # Shuffle tokens if --shuffle_token is specified
        if args.shuffle_token:
            random.shuffle(token_sequence)
        y = decode(token_sequence) + '\n'
        f.write(y)
        # Flush after every sample so finished samples reach the file even if
        # a long run is interrupted.
        f.flush()