winglian committed on
Commit
2db9436
·
1 Parent(s): 120e7df

casts the prepared data to int16 (doesn't help with training memory)

Browse files
Files changed (2) hide show
  1. scripts/finetune.py +1 -2
  2. src/axolotl/datasets.py +12 -5
scripts/finetune.py CHANGED
@@ -14,7 +14,6 @@ import transformers
14
  import yaml
15
  from attrdict import AttrDefault
16
  from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
17
- from huggingface_hub.hf_api import DatasetInfo
18
  from torch import nn
19
  from transformers import (
20
  AutoModelForCausalLM,
@@ -169,7 +168,7 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
169
 
170
  if cfg.load_4bit:
171
  # Scales to half
172
- print('Fitting 4bit scales and zeros to half')
173
  for n, m in model.named_modules():
174
  if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
175
  if hasattr(m, "is_v1_model") and m.is_v1_model:
 
14
  import yaml
15
  from attrdict import AttrDefault
16
  from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
 
17
  from torch import nn
18
  from transformers import (
19
  AutoModelForCausalLM,
 
168
 
169
  if cfg.load_4bit:
170
  # Scales to half
171
+ logging.info('Fitting 4bit scales and zeros to half')
172
  for n, m in model.named_modules():
173
  if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
174
  if hasattr(m, "is_v1_model") and m.is_v1_model:
src/axolotl/datasets.py CHANGED
@@ -30,7 +30,6 @@ class TokenizedPromptDataset(IterableDataset):
30
  except InvalidDataException:
31
  pass
32
 
33
-
34
  # TODO this isn't the best since it can't interleave datasets
35
  class ConstantLengthDataset(IterableDataset):
36
  """
@@ -40,7 +39,6 @@ class ConstantLengthDataset(IterableDataset):
40
  dataset (dataset.Dataset): Dataset with text files.
41
  seq_length (int): Length of token sequences to return.
42
  """
43
-
44
  def __init__(
45
  self,
46
  tokenizer,
@@ -52,6 +50,15 @@ class ConstantLengthDataset(IterableDataset):
52
  self.datasets: List[IterableDataset] = datasets
53
  self.seq_length = seq_length
54
 
 
 
 
 
 
 
 
 
 
55
  def __iter__(self):
56
  buffer = {"input_ids": [], "attention_mask": [], "labels": []}
57
  buffer_len = 0
@@ -105,11 +112,11 @@ class ConstantLengthDataset(IterableDataset):
105
  attention_mask.append(1)
106
  labels.append(self.concat_token_id)
107
 
108
- input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
109
  attention_mask_with_concat = torch.tensor(
110
- attention_mask, dtype=torch.long
111
  )
112
- labels_with_concat = torch.tensor(labels, dtype=torch.long)
113
 
114
  buffer["input_ids"].append(input_ids_with_concat)
115
  buffer["attention_mask"].append(attention_mask_with_concat)
 
30
  except InvalidDataException:
31
  pass
32
 
 
33
  # TODO this isn't the best since it can't interleave datasets
34
  class ConstantLengthDataset(IterableDataset):
35
  """
 
39
  dataset (dataset.Dataset): Dataset with text files.
40
  seq_length (int): Length of token sequences to return.
41
  """
 
42
  def __init__(
43
  self,
44
  tokenizer,
 
50
  self.datasets: List[IterableDataset] = datasets
51
  self.seq_length = seq_length
52
 
53
+ vocab_size = len(tokenizer.get_vocab())
54
+
55
+ if vocab_size <= torch.iinfo(torch.int16).max:
56
+ self.tokens_dtype = torch.int16
57
+ elif vocab_size <= torch.iinfo(torch.int32).max:
58
+ self.tokens_dtype = torch.int32
59
+ else:
60
+ self.tokens_dtype = torch.int64
61
+
62
  def __iter__(self):
63
  buffer = {"input_ids": [], "attention_mask": [], "labels": []}
64
  buffer_len = 0
 
112
  attention_mask.append(1)
113
  labels.append(self.concat_token_id)
114
 
115
+ input_ids_with_concat = torch.tensor(input_ids, dtype=self.tokens_dtype)
116
  attention_mask_with_concat = torch.tensor(
117
+ attention_mask, dtype=self.tokens_dtype
118
  )
119
+ labels_with_concat = torch.tensor(labels, dtype=self.tokens_dtype)
120
 
121
  buffer["input_ids"].append(input_ids_with_concat)
122
  buffer["attention_mask"].append(attention_mask_with_concat)