boris commited on
Commit
3f0364c
·
1 Parent(s): fad333f

feat: adjust seq2seq script for dalle

Browse files
seq2seq/requirements.txt CHANGED
@@ -3,3 +3,4 @@ jax>=0.2.8
3
  jaxlib>=0.1.59
4
  flax>=0.3.4
5
  optax>=0.0.8
 
 
3
  jaxlib>=0.1.59
4
  flax>=0.3.4
5
  optax>=0.0.8
6
+ tensorboard
seq2seq/run_seq2seq_flax.py CHANGED
@@ -40,6 +40,7 @@ import optax
40
  import transformers
41
  from filelock import FileLock
42
  from flax import jax_utils, traverse_util
 
43
  from flax.jax_utils import unreplicate
44
  from flax.training import train_state
45
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
@@ -49,12 +50,15 @@ from transformers import (
49
  AutoConfig,
50
  AutoTokenizer,
51
  FlaxAutoModelForSeq2SeqLM,
 
52
  HfArgumentParser,
53
  TrainingArguments,
54
  is_tensorboard_available,
55
  )
 
56
  from transformers.file_utils import is_offline_mode
57
 
 
58
 
59
  logger = logging.getLogger(__name__)
60
 
@@ -73,6 +77,13 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
73
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
74
 
75
 
 
 
 
 
 
 
 
76
  @dataclass
77
  class ModelArguments:
78
  """
@@ -80,7 +91,7 @@ class ModelArguments:
80
  """
81
 
82
  model_name_or_path: Optional[str] = field(
83
- default=None,
84
  metadata={
85
  "help": "The model checkpoint for weights initialization."
86
  "Don't set if you want to train a model from scratch."
@@ -124,12 +135,12 @@ class DataTrainingArguments:
124
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
125
  )
126
  text_column: Optional[str] = field(
127
- default=None,
128
  metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
129
  )
130
- summary_column: Optional[str] = field(
131
- default=None,
132
- metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
133
  )
134
  train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
135
  validation_file: Optional[str] = field(
@@ -148,7 +159,7 @@ class DataTrainingArguments:
148
  },
149
  )
150
  max_target_length: Optional[int] = field(
151
- default=128,
152
  metadata={
153
  "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
154
  "than this will be truncated, sequences shorter will be padded."
@@ -219,21 +230,6 @@ class DataTrainingArguments:
219
  self.val_max_target_length = self.max_target_length
220
 
221
 
222
- summarization_name_mapping = {
223
- "amazon_reviews_multi": ("review_body", "review_title"),
224
- "big_patent": ("description", "abstract"),
225
- "cnn_dailymail": ("article", "highlights"),
226
- "orange_sum": ("text", "summary"),
227
- "pn_summary": ("article", "summary"),
228
- "psc": ("extract_text", "summary_text"),
229
- "samsum": ("dialogue", "summary"),
230
- "thaisum": ("body", "summary"),
231
- "xglue": ("news_body", "news_title"),
232
- "xsum": ("document", "summary"),
233
- "wiki_summary": ("article", "highlights"),
234
- }
235
-
236
-
237
  class TrainState(train_state.TrainState):
238
  dropout_rng: jnp.ndarray
239
 
@@ -241,6 +237,45 @@ class TrainState(train_state.TrainState):
241
  return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
242
 
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
245
  """
246
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
@@ -315,6 +350,15 @@ def main():
315
  f"Output directory ({training_args.output_dir}) already exists and is not empty."
316
  "Use --overwrite_output_dir to overcome."
317
  )
 
 
 
 
 
 
 
 
 
318
 
319
  # Make one log on every process with the configuration for debugging.
320
  logging.basicConfig(
@@ -338,64 +382,41 @@ def main():
338
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
339
  # (the dataset will be downloaded automatically from the datasets Hub).
340
  #
341
- # For CSV/JSON files this script will use the first column for the full texts and the second column for the
342
- # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
343
- #
344
- if data_args.dataset_name is not None:
345
- # Downloading and loading a dataset from the hub.
346
- dataset = load_dataset(
347
- data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
348
- )
349
- else:
350
- data_files = {}
351
- if data_args.train_file is not None:
352
- data_files["train"] = data_args.train_file
353
- extension = data_args.train_file.split(".")[-1]
354
- if data_args.validation_file is not None:
355
- data_files["validation"] = data_args.validation_file
356
- extension = data_args.validation_file.split(".")[-1]
357
- if data_args.test_file is not None:
358
- data_files["test"] = data_args.test_file
359
- extension = data_args.test_file.split(".")[-1]
360
- dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
361
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
362
  # https://huggingface.co/docs/datasets/loading_datasets.html.
363
 
364
  # Load pretrained model and tokenizer
 
 
 
 
 
 
365
 
366
- if model_args.config_name:
367
- config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
368
- elif model_args.model_name_or_path:
369
- config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
370
- else:
371
- config = CONFIG_MAPPING[model_args.model_type]()
372
- logger.warning("You are instantiating a new config instance from scratch.")
373
 
374
- if model_args.tokenizer_name:
375
- tokenizer = AutoTokenizer.from_pretrained(
376
- model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
377
- )
378
- elif model_args.model_name_or_path:
379
- tokenizer = AutoTokenizer.from_pretrained(
380
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
381
- )
382
- else:
383
- raise ValueError(
384
- "You are instantiating a new tokenizer from scratch. This is not supported by this script."
385
- "You can do it from another script, save it, and load it from here, using --tokenizer_name."
386
- )
387
 
388
- if model_args.model_name_or_path:
389
- model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
390
- model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
391
- )
392
- else:
393
- model = FlaxAutoModelForSeq2SeqLM.from_config(
394
- config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
395
- )
396
 
397
- if model.config.decoder_start_token_id is None:
398
- raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
 
399
 
400
  prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
401
 
@@ -412,23 +433,8 @@ def main():
412
  return
413
 
414
  # Get the column names for input/target.
415
- dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
416
- if data_args.text_column is None:
417
- text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
418
- else:
419
- text_column = data_args.text_column
420
- if text_column not in column_names:
421
- raise ValueError(
422
- f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
423
- )
424
- if data_args.summary_column is None:
425
- summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
426
- else:
427
- summary_column = data_args.summary_column
428
- if summary_column not in column_names:
429
- raise ValueError(
430
- f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
431
- )
432
 
433
  # Temporarily set max_target_length for training.
434
  max_target_length = data_args.max_target_length
@@ -442,26 +448,26 @@ def main():
442
  # Setting padding="max_length" as we need fixed length inputs for jitted functions
443
  def preprocess_function(examples):
444
  inputs = examples[text_column]
445
- targets = examples[summary_column]
446
  inputs = [prefix + inp for inp in inputs]
447
  model_inputs = tokenizer(
448
  inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
449
  )
450
 
451
- # Setup the tokenizer for targets
452
- with tokenizer.as_target_tokenizer():
453
- labels = tokenizer(
454
- targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
455
- )
456
 
457
- model_inputs["labels"] = labels["input_ids"]
 
 
458
  decoder_input_ids = shift_tokens_right_fn(
459
  jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
460
  )
 
461
  model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
462
 
463
  # We need decoder_attention_mask so we can ignore pad tokens from loss
464
- model_inputs["decoder_attention_mask"] = labels["attention_mask"]
 
465
 
466
  return model_inputs
467
 
 
40
  import transformers
41
  from filelock import FileLock
42
  from flax import jax_utils, traverse_util
43
+ import flax.linen as nn
44
  from flax.jax_utils import unreplicate
45
  from flax.training import train_state
46
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 
50
  AutoConfig,
51
  AutoTokenizer,
52
  FlaxAutoModelForSeq2SeqLM,
53
+ FlaxBartForConditionalGeneration,
54
  HfArgumentParser,
55
  TrainingArguments,
56
  is_tensorboard_available,
57
  )
58
+ from transformers.models.bart.modeling_flax_bart import *
59
  from transformers.file_utils import is_offline_mode
60
 
61
+ import wandb
62
 
63
  logger = logging.getLogger(__name__)
64
 
 
77
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
78
 
79
 
80
+ # Model hyperparameters, for convenience
81
+ OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
82
+ OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
83
+ BOS_TOKEN_ID = 16384
84
+ BASE_MODEL = 'facebook/bart-large-cnn'
85
+
86
+
87
  @dataclass
88
  class ModelArguments:
89
  """
 
91
  """
92
 
93
  model_name_or_path: Optional[str] = field(
94
+ default=BASE_MODEL,
95
  metadata={
96
  "help": "The model checkpoint for weights initialization."
97
  "Don't set if you want to train a model from scratch."
 
135
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
136
  )
137
  text_column: Optional[str] = field(
138
+ default='caption',
139
  metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
140
  )
141
+ encoding_column: Optional[str] = field(
142
+ default='encoding',
143
+ metadata={"help": "The name of the column in the datasets containing the image encodings."},
144
  )
145
  train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
146
  validation_file: Optional[str] = field(
 
159
  },
160
  )
161
  max_target_length: Optional[int] = field(
162
+ default=OUTPUT_LENGTH,
163
  metadata={
164
  "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
165
  "than this will be truncated, sequences shorter will be padded."
 
230
  self.val_max_target_length = self.max_target_length
231
 
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  class TrainState(train_state.TrainState):
234
  dropout_rng: jnp.ndarray
235
 
 
237
  return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
238
 
239
 
240
+ class CustomFlaxBartModule(FlaxBartModule):
241
+ def setup(self):
242
+ # we keep shared to easily load pre-trained weights
243
+ self.shared = nn.Embed(
244
+ self.config.vocab_size,
245
+ self.config.d_model,
246
+ embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
247
+ dtype=self.dtype,
248
+ )
249
+ # a separate embedding is used for the decoder
250
+ self.decoder_embed = nn.Embed(
251
+ OUTPUT_VOCAB_SIZE,
252
+ self.config.d_model,
253
+ embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
254
+ dtype=self.dtype,
255
+ )
256
+ self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
257
+
258
+ # the decoder has a different config
259
+ decoder_config = BartConfig(self.config.to_dict())
260
+ decoder_config.max_position_embeddings = OUTPUT_LENGTH
261
+ decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
262
+ self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
263
+
264
+ class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
265
+ def setup(self):
266
+ self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
267
+ self.lm_head = nn.Dense(
268
+ OUTPUT_VOCAB_SIZE,
269
+ use_bias=False,
270
+ dtype=self.dtype,
271
+ kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
272
+ )
273
+ self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
274
+
275
+ class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
276
+ module_class = CustomFlaxBartForConditionalGenerationModule
277
+
278
+
279
  def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
280
  """
281
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
 
350
  f"Output directory ({training_args.output_dir}) already exists and is not empty."
351
  "Use --overwrite_output_dir to overcome."
352
  )
353
+
354
+ # Set up wandb run
355
+ wandb.init(
356
+ sync_tensorboard=True,
357
+ entity='wandb',
358
+ project='hf-flax-dalle-mini',
359
+ job_type='Seq2SeqVQGAN',
360
+ config=parser.parse_args()
361
+ )
362
 
363
  # Make one log on every process with the configuration for debugging.
364
  logging.basicConfig(
 
382
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
383
  # (the dataset will be downloaded automatically from the datasets Hub).
384
  #
385
+ data_files = {}
386
+ if data_args.train_file is not None:
387
+ data_files["train"] = data_args.train_file
388
+ if data_args.validation_file is not None:
389
+ data_files["validation"] = data_args.validation_file
390
+ if data_args.test_file is not None:
391
+ data_files["test"] = data_args.test_file
392
+ dataset = load_dataset"csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
 
 
 
 
 
 
 
 
 
 
 
 
393
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
394
  # https://huggingface.co/docs/datasets/loading_datasets.html.
395
 
396
  # Load pretrained model and tokenizer
397
+ base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
398
+ model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
399
+ )
400
+ tokenizer = AutoTokenizer.from_pretrained(
401
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
402
+ )
403
 
404
+ # Set up our new model config
405
+ config = BartConfig.from_pretrained(model_args.model_name_or_path)
406
+ config.tie_word_embeddings = False
407
+ config.decoder_start_token_id = BOS_TOKEN_ID
408
+ config.bos_token_id = BOS_TOKEN_ID # should not be used
409
+ config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
410
+ config.eos_token_id = None # prevents generation from stopping until we reach max_length
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
413
+ # Create a custom model and initialize it randomly
414
+ model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
 
 
 
 
 
 
415
 
416
+ # Use pre-trained weights for encoder
417
+ model.params['model']['encoder'] = base_model.params['model']['encoder']
418
+ model.params['model']['shared'] = base_model.params['model']['shared']
419
+ del base_model
420
 
421
  prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
422
 
 
433
  return
434
 
435
  # Get the column names for input/target.
436
+ text_column = data_args.text_column
437
+ encoding_column = data_args.encoding_column
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
439
  # Temporarily set max_target_length for training.
440
  max_target_length = data_args.max_target_length
 
448
  # Setting padding="max_length" as we need fixed length inputs for jitted functions
449
  def preprocess_function(examples):
450
  inputs = examples[text_column]
 
451
  inputs = [prefix + inp for inp in inputs]
452
  model_inputs = tokenizer(
453
  inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
454
  )
455
 
456
+ # set up targets
457
+ model_inputs["labels"] = [eval(indices) for indices in examples['encoding']]
 
 
 
458
 
459
+ # TODO: if data processing prevents correct compilation, we will:
460
+ # - have data saved in JSONL (to avoid `eval` which is needed here to convert string "[2]" to list[int])
461
+ # - use below `shift_tokens_right_fn`
462
  decoder_input_ids = shift_tokens_right_fn(
463
  jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
464
  )
465
+
466
  model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
467
 
468
  # We need decoder_attention_mask so we can ignore pad tokens from loss
469
+ # TODO: I don't believe we need "decoder_attention_mask" in this case because all labels have same length
470
+ #model_inputs["decoder_attention_mask"] = labels["attention_mask"]
471
 
472
  return model_inputs
473