from torch.utils.data import Sampler
import sys
sys.path.append("..")
from dataset.autocorrect_dataset import SpellCorrectDataset
import numpy as np
from params import RANDOM_SEED, MAXIMUM_TOKENS_PER_BATCH
import copy
from tqdm import tqdm
import time

class RandomBatchSampler(Sampler):
    """Yields fixed-size batches of dataset indices, optionally shuffled."""

    def __init__(self, data: SpellCorrectDataset, batch_size=1, shuffle=True):
        self.data = data
        self.seq = list(range(len(self.data)))
        self.shuffle = shuffle
        self.iters = 0
        self.batch_size = batch_size
        if self.shuffle:
            np.random.seed(RANDOM_SEED)
            np.random.shuffle(self.seq)
        # Chunk the (shuffled) index list into batches of `batch_size`.
        self.seq = [self.seq[index: index + self.batch_size]
                    for index in range(0, len(self.seq), self.batch_size)]
        self.default_seq = copy.deepcopy(self.seq)

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

    def load_checkpoints(self, iters=0):
        """Rebuild the batch list, skipping the `iters` batches already consumed."""
        self.seq = list(range(len(self.data)))
        if self.shuffle:
            # Re-seed so the shuffle reproduces the original epoch order.
            np.random.seed(RANDOM_SEED)
            np.random.shuffle(self.seq)
        self.iters = iters
        # `iters` counts completed batches, so skip `iters * batch_size` elements.
        self.seq = [self.seq[index: index + self.batch_size]
                    for index in range(self.iters * self.batch_size, len(self.seq), self.batch_size)]

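# For reference, the fixed-size chunking and checkpoint resume above reduce to
# this minimal, self-contained sketch (the toy `indices`, `batch_size`, and
# `iters_done` values are stand-ins, and `random` stands in for the seeded
# NumPy shuffle):

```python
import random

random.seed(42)                      # stand-in for RANDOM_SEED
indices = list(range(10))            # stand-in for range(len(dataset))
random.shuffle(indices)

batch_size = 3
batches = [indices[i:i + batch_size]
           for i in range(0, len(indices), batch_size)]

# Resuming from a checkpoint: skip the batches already consumed.
iters_done = 2
resumed = batches[iters_done:]
```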
class BucketBatchSampler(Sampler):
    """Buckets length-sorted examples into batches capped by a token budget."""

    def __init__(self, data: SpellCorrectDataset, shuffle=True):
        start = time.time()
        self.remaining_indices = None
        self.data = data
        self.shuffle = shuffle
        print("Initializing Bucket Batch Sampler From Scratch")
        # Sort examples by token count (stored at position 2 of each record)
        # so each batch groups similarly sized examples.
        self.data.dataset = sorted(self.data.dataset, key=lambda x: x[2])
        token_counts = 0
        batch_indices = []
        self.seq = []
        for index, values in tqdm(enumerate(self.data.dataset), total=len(self.data.dataset)):
            # Flush the current batch once the token budget has been reached.
            if token_counts >= MAXIMUM_TOKENS_PER_BATCH:
                self.seq.append(batch_indices)
                batch_indices = []
                token_counts = 0
            batch_indices.append(index)
            token_counts += values[2]
        if batch_indices:
            self.seq.append(batch_indices)

        if shuffle:
            np.random.seed(RANDOM_SEED)
            np.random.shuffle(self.seq)
        end = time.time()
        print(f"Initialized Bucket Batch Sampler From Scratch: {end - start:.2f}s")
        self.default_seq = copy.deepcopy(self.seq)
        
    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

    def load_checkpoints(self, remaining_indices):
        """Rebuild token-budget batches from the indices not yet consumed."""
        start = time.time()
        print("Loading Bucket Batch Sampler From Checkpoint")
        # The dataset was sorted by token count in __init__, so sorting the
        # remaining indices keeps the examples bucketed by length.
        remaining_indices = sorted(remaining_indices)
        token_counts = 0
        batch_indices = []
        self.seq = []
        for index in tqdm(remaining_indices):
            values = self.data.dataset[index]
            if token_counts >= MAXIMUM_TOKENS_PER_BATCH:
                self.seq.append(batch_indices)
                batch_indices = []
                token_counts = 0
            batch_indices.append(index)
            token_counts += values[2]
        if batch_indices:
            self.seq.append(batch_indices)

        if self.shuffle:
            np.random.seed(RANDOM_SEED)
            np.random.shuffle(self.seq)
        end = time.time()
        print(f"Loaded Bucket Batch Sampler From Checkpoint: {end - start:.2f}s")
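
# The token-budget bucketing can be illustrated with a self-contained toy
# example (the `lengths` list and `MAX_TOKENS` value are made-up stand-ins for
# the dataset's per-example token counts and MAXIMUM_TOKENS_PER_BATCH). Note
# that because the budget check runs before an example is added, a batch may
# overshoot the budget by up to its final example's length:

```python
MAX_TOKENS = 10
lengths = [3, 1, 4, 1, 5, 9, 2, 6]   # tokens per example

# Sort indices by length so each batch groups similarly sized examples.
order = sorted(range(len(lengths)), key=lambda i: lengths[i])

batches, current, total = [], [], 0
for i in order:
    if total >= MAX_TOKENS:          # flush once the budget is reached
        batches.append(current)
        current, total = [], 0
    current.append(i)
    total += lengths[i]
if current:                          # keep the final partial batch
    batches.append(current)
```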