Spaces:
Runtime error
Runtime error
from torch.utils.data.dataloader import Sampler | |
import sys | |
sys.path.append("..") | |
from dataset.autocorrect_dataset import SpellCorrectDataset | |
import numpy as np | |
from params import RANDOM_SEED, MAXIMUM_TOKENS_PER_BATCH | |
import copy | |
from tqdm import tqdm | |
import time | |
class RandomBatchSampler(Sampler): | |
def __init__(self, data: SpellCorrectDataset, batch_size = 1, shuffle = True): | |
self.data = data | |
self.seq = list(range(0, len(self.data))) | |
self.shuffle = shuffle | |
self.iters = 0 | |
self.batch_size = batch_size | |
if self.shuffle: | |
np.random.seed(RANDOM_SEED) | |
np.random.shuffle(self.seq) | |
self.seq = [ self.seq[index: index + self.batch_size] \ | |
for index in range(self.iters, len(self.seq), self.batch_size)] | |
self.default_seq = copy.deepcopy(self.seq) | |
def __iter__(self): | |
return iter(self.seq) | |
def __len__(self): | |
return len(self.seq) | |
def load_checkpoints(self, iters = 0): | |
self.seq = list(range(0, len(self.data))) | |
if self.shuffle: | |
np.random.seed(RANDOM_SEED) | |
np.random.shuffle(self.seq) | |
self.iters = iters | |
self.seq = [ self.seq[index: index + self.batch_size] \ | |
for index in range(self.iters, len(self.seq), self.batch_size)] | |
class BucketBatchSampler(Sampler): | |
def __init__(self, data: SpellCorrectDataset, shuffle = True): | |
start = time.time() | |
self.remained_indies = None | |
self.data = data | |
self.shuffle = shuffle | |
print("Initializing Bucket Batch Sampler From Scratch") | |
self.data.dataset = sorted(self.data.dataset, key = lambda x: x[2]) | |
token_counts = 0 | |
indies_lists = [] | |
self.seq = [] | |
for index, values in tqdm(enumerate(self.data.dataset)): | |
if token_counts >= MAXIMUM_TOKENS_PER_BATCH: | |
self.seq.append(indies_lists) | |
indies_lists = [] | |
token_counts = 0 | |
indies_lists.append(index) | |
token_counts += values[2] | |
if len(indies_lists) != 0 and token_counts != 0: | |
self.seq.append(indies_lists) | |
if shuffle: | |
np.random.seed(RANDOM_SEED) | |
np.random.shuffle(self.seq) | |
end = time.time() | |
print(f"Initialized Bucket Batch Sampler From Scratch: {end - start}") | |
self.default_seq = copy.deepcopy(self.seq) | |
def __iter__(self): | |
return iter(self.seq) | |
def __len__(self): | |
return len(self.seq) | |
def load_checkpoints(self, remained_indies): | |
start = time.time() | |
print("Loading Bucket Batch Sampler From Checkpoint") | |
remained_indies = sorted(remained_indies) | |
token_counts = 0 | |
indies_lists = [] | |
self.seq = [] | |
for index in tqdm(remained_indies): | |
values = self.data.dataset[index] | |
if token_counts >= MAXIMUM_TOKENS_PER_BATCH: | |
self.seq.append(indies_lists) | |
indies_lists = [] | |
token_counts = 0 | |
indies_lists.append(index) | |
token_counts += values[2] | |
if len(indies_lists) != 0 and token_counts != 0: | |
self.seq.append(indies_lists) | |
if self.shuffle: | |
np.random.seed(RANDOM_SEED) | |
np.random.shuffle(self.seq) | |
end = time.time() | |
print(f"Loaded Bucket Batch Sampler From Checkpoint: {end - start}") | |