import sys
import copy
import time

import numpy as np
from tqdm import tqdm
from torch.utils.data import Sampler

sys.path.append("..")
from dataset.autocorrect_dataset import SpellCorrectDataset
from params import RANDOM_SEED, MAXIMUM_TOKENS_PER_BATCH

class RandomBatchSampler(Sampler):
    """Yields fixed-size batches of dataset indices in a (optionally shuffled) order."""

    def __init__(self, data: SpellCorrectDataset, batch_size=1, shuffle=True):
        self.data = data
        self.seq = list(range(len(self.data)))
        self.shuffle = shuffle
        self.iters = 0
        self.batch_size = batch_size

        if self.shuffle:
            np.random.seed(RANDOM_SEED)
            np.random.shuffle(self.seq)

        # Chunk the (shuffled) index list into batches of `batch_size`.
        self.seq = [self.seq[index: index + self.batch_size]
                    for index in range(self.iters, len(self.seq), self.batch_size)]
        self.default_seq = copy.deepcopy(self.seq)

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

    def load_checkpoints(self, iters=0):
        """Rebuild the batch list, skipping the first `iters` indices already consumed."""
        self.seq = list(range(len(self.data)))
        if self.shuffle:
            np.random.seed(RANDOM_SEED)
            np.random.shuffle(self.seq)
        self.iters = iters
        self.seq = [self.seq[index: index + self.batch_size]
                    for index in range(self.iters, len(self.seq), self.batch_size)]
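
# Usage sketch (assumption: the SpellCorrectDataset constructor arguments below are
# placeholders, not part of this repository's verified API). Each item yielded by the
# sampler is a list of indices, which is exactly what DataLoader's `batch_sampler` expects:
#
#     from torch.utils.data import DataLoader
#     dataset = SpellCorrectDataset(...)                      # hypothetical arguments
#     sampler = RandomBatchSampler(dataset, batch_size=32, shuffle=True)
#     loader = DataLoader(dataset, batch_sampler=sampler)
#     for batch in loader:
#         ...
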

class BucketBatchSampler(Sampler):
    """Groups samples of similar length into batches capped at MAXIMUM_TOKENS_PER_BATCH tokens."""

    def __init__(self, data: SpellCorrectDataset, shuffle=True):
        start = time.time()
        self.remained_indies = None
        self.data = data
        self.shuffle = shuffle

        print("Initializing Bucket Batch Sampler From Scratch")
        # Sort samples by token count (stored at position 2 of each dataset entry)
        # so that each batch contains sequences of similar length.
        self.data.dataset = sorted(self.data.dataset, key=lambda x: x[2])

        token_counts = 0
        indies_lists = []
        self.seq = []
        for index, values in tqdm(enumerate(self.data.dataset), total=len(self.data.dataset)):
            # Start a new batch once the running token count reaches the cap.
            if token_counts >= MAXIMUM_TOKENS_PER_BATCH:
                self.seq.append(indies_lists)
                indies_lists = []
                token_counts = 0
            indies_lists.append(index)
            token_counts += values[2]
        # Keep the last, possibly under-filled batch.
        if len(indies_lists) != 0 and token_counts != 0:
            self.seq.append(indies_lists)

        if shuffle:
            np.random.seed(RANDOM_SEED)
            np.random.shuffle(self.seq)

        end = time.time()
        print(f"Initialized Bucket Batch Sampler From Scratch: {end - start:.2f}s")
        self.default_seq = copy.deepcopy(self.seq)

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

    def load_checkpoints(self, remained_indies):
        """Rebuild the token-capped batches from the indices that have not been consumed yet."""
        start = time.time()
        print("Loading Bucket Batch Sampler From Checkpoint")

        remained_indies = sorted(remained_indies)
        token_counts = 0
        indies_lists = []
        self.seq = []
        for index in tqdm(remained_indies):
            values = self.data.dataset[index]
            if token_counts >= MAXIMUM_TOKENS_PER_BATCH:
                self.seq.append(indies_lists)
                indies_lists = []
                token_counts = 0
            indies_lists.append(index)
            token_counts += values[2]
        if len(indies_lists) != 0 and token_counts != 0:
            self.seq.append(indies_lists)

        if self.shuffle:
            np.random.seed(RANDOM_SEED)
            np.random.shuffle(self.seq)

        end = time.time()
        print(f"Loaded Bucket Batch Sampler From Checkpoint: {end - start:.2f}s")