winglian commited on
Commit
f733d0f
·
unverified ·
1 Parent(s): 008505c

disable eval using multipack for now (#437)

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/trainer.py +22 -22
src/axolotl/utils/trainer.py CHANGED
@@ -14,7 +14,7 @@ import bitsandbytes as bnb
14
  import numpy as np
15
  import torch.cuda
16
  import transformers
17
- from datasets import Dataset, set_caching_enabled
18
  from torch import nn
19
  from torch.optim.lr_scheduler import OneCycleLR
20
  from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
@@ -188,27 +188,27 @@ class AxolotlTrainer(Trainer):
188
  )
189
  return super().get_train_dataloader()
190
 
191
- def get_eval_dataloader(
192
- self, eval_dataset: Optional[Dataset] = None
193
- ) -> Union[DataLoader, MultipackDistributedDataloader]:
194
- if self.args.sample_packing:
195
- eval_dataset = (
196
- eval_dataset if eval_dataset is not None else self.eval_dataset
197
- )
198
- eval_sampler = self._get_eval_sampler(eval_dataset)
199
- return self.accelerator.prepare(
200
- MultipackDistributedDataloader(
201
- eval_dataset,
202
- batch_size=self.args.eval_batch_size,
203
- seq_max_length=self.args.max_seq_length,
204
- collate_fn=self.data_collator,
205
- sampler=eval_sampler,
206
- packing_efficiency_estimate=self.args.sample_packing_efficiency,
207
- sample_packing_seq_len_multiplier=self.args.eval_batch_size,
208
- device_count=int(os.environ.get("WORLD_SIZE", 1)),
209
- )
210
- )
211
- return super().get_eval_dataloader(eval_dataset)
212
 
213
  def compute_loss(self, model, inputs, return_outputs=False):
214
  # use one's weighted cross entropy loss calc
 
14
  import numpy as np
15
  import torch.cuda
16
  import transformers
17
+ from datasets import set_caching_enabled
18
  from torch import nn
19
  from torch.optim.lr_scheduler import OneCycleLR
20
  from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 
188
  )
189
  return super().get_train_dataloader()
190
 
191
+ # def get_eval_dataloader(
192
+ # self, eval_dataset: Optional[Dataset] = None
193
+ # ) -> Union[DataLoader, MultipackDistributedDataloader]:
194
+ # if self.args.sample_packing:
195
+ # eval_dataset = (
196
+ # eval_dataset if eval_dataset is not None else self.eval_dataset
197
+ # )
198
+ # eval_sampler = self._get_eval_sampler(eval_dataset)
199
+ # return self.accelerator.prepare(
200
+ # MultipackDistributedDataloader(
201
+ # eval_dataset,
202
+ # batch_size=self.args.eval_batch_size,
203
+ # seq_max_length=self.args.max_seq_length,
204
+ # collate_fn=self.data_collator,
205
+ # sampler=eval_sampler,
206
+ # packing_efficiency_estimate=self.args.sample_packing_efficiency,
207
+ # sample_packing_seq_len_multiplier=self.args.eval_batch_size,
208
+ # device_count=int(os.environ.get("WORLD_SIZE", 1)),
209
+ # )
210
+ # )
211
+ # return super().get_eval_dataloader(eval_dataset)
212
 
213
  def compute_loss(self, model, inputs, return_outputs=False):
214
  # use one's weighted cross entropy loss calc