AutoModelForSequenceClassification error: "Expected input batch_size (1080) to match target batch_size (2)"

#1
by fish - opened

Thank you for developing such a portable tool!
I am currently following this tutorial for fine-tuning on sequence classification:

https://www.heywhale.com/mw/project/6459d2cf08d33de8e1e8fa05/content

It works well with ESM2, but when I switch the model to ESMplusplus_large, the code raises a batch size error.
Here is the code that loads the model:

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
modelClassification = AutoModelForSequenceClassification.from_pretrained("ESMplusplus_large", trust_remote_code=True, num_labels=2).to("cuda:0")
tokenizer = modelClassification.tokenizer

Here is the batch size info (obtained by adding debug prints to the model source code). The batch size of both the data and the labels is 2, which looks correct:

logits.shape: torch.Size([2, 540, 64])
logits.view(-1, self.vocab_size).shape: torch.Size([1080, 64])
labels.shape: torch.Size([2])
labels.view(-1).shape: torch.Size([2])
logits.shape: torch.Size([2, 540, 64])
logits.view(-1, self.vocab_size).shape: torch.Size([1080, 64])
labels.shape: torch.Size([2])
labels.view(-1).shape: torch.Size([2])
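For reference, the mismatch can be reproduced on its own with these shapes (a standalone sketch using random tensors, not the actual model code): flattening the logits over the sequence dimension gives 2 * 540 = 1080 rows, while the labels stay at 2.

import torch
import torch.nn.functional as F

# Same shapes as in the debug prints above (batch=2, seq_len=540, vocab_size=64).
logits = torch.randn(2, 540, 64)
labels = torch.randint(0, 2, (2,))

try:
    # Flattening over the sequence dimension produces 1080 rows vs. 2 targets.
    F.cross_entropy(logits.view(-1, 64), labels.view(-1))
except ValueError as e:
    print(e)  # Expected input batch_size (1080) to match target batch_size (2).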

Errors:

ValueError Traceback (most recent call last)
Cell In[11], line 1
----> 1 trainer.train()

File ~/miniforge3/envs/esmIne3/lib/python3.12/site-packages/transformers/trainer.py:2123, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2121 hf_hub_utils.enable_progress_bars()
2122 else:
-> 2123 return inner_training_loop(
2124 args=args,
2125 resume_from_checkpoint=resume_from_checkpoint,
2126 trial=trial,
2127 ignore_keys_for_eval=ignore_keys_for_eval,
2128 )

File ~/miniforge3/envs/esmIne3/lib/python3.12/site-packages/transformers/trainer.py:2481, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2475 context = (
2476 functools.partial(self.accelerator.no_sync, model=model)
2477 if i != len(batch_samples) - 1
2478 else contextlib.nullcontext
2479 )
2480 with context():
-> 2481 tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
2483 if (
2484 args.logging_nan_inf_filter
2485 and not is_torch_xla_available()
2486 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
2487 ):
2488 # if loss is nan or inf simply add the average of previous logged losses
2489 tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File ~/miniforge3/envs/esmIne3/lib/python3.12/site-packages/transformers/trainer.py:3579, in Trainer.training_step(self, model, inputs, num_items_in_batch)
3576 return loss_mb.reduce_mean().detach().to(self.args.device)
3578 with self.compute_loss_context_manager():
-> 3579 loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
3581 del inputs
3582 if (
3583 self.args.torch_empty_cache_steps is not None
3584 and self.state.global_step % self.args.torch_empty_cache_steps == 0
3585 ):

File ~/miniforge3/envs/esmIne3/lib/python3.12/site-packages/transformers/trainer.py:3633, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3631 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3632 inputs = {**inputs, **loss_kwargs}
-> 3633 outputs = model(**inputs)
3634 # Save past state if it exists
3635 # TODO: this needs to be fixed and made cleaner later.
3636 if self.args.past_index >= 0:

File ~/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()

File ~/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py:193, in DataParallel.forward(self, *inputs, **kwargs)
191 return self.module(*inputs[0], **module_kwargs[0])
192 replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
--> 193 outputs = self.parallel_apply(replicas, inputs, module_kwargs)
194 return self.gather(outputs, self.output_device)

File ~/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py:212, in DataParallel.parallel_apply(self, replicas, inputs, kwargs)
209 def parallel_apply(
210 self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
211 ) -> List[Any]:
--> 212 return parallel_apply(
213 replicas, inputs, kwargs, self.device_ids[: len(replicas)]
214 )

File ~/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py:126, in parallel_apply(modules, inputs, kwargs_tup, devices)
124 output = results[i]
125 if isinstance(output, ExceptionWrapper):
--> 126 output.reraise()
127 outputs.append(output)
128 return outputs

File ~/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/_utils.py:715, in ExceptionWrapper.reraise(self)
711 except TypeError:
712 # If the exception takes multiple arguments, don't try to
713 # instantiate since we don't know how to
714 raise RuntimeError(msg) from None
--> 715 raise exception

ValueError: Caught ValueError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/root/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
output = module(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/ESMplusplus_large/modeling_esm_plusplus.py", line 676, in forward
output = super().forward(input_ids, attention_mask, labels, output_hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/ESMplusplus_large/modeling_esm_plusplus.py", line 635, in forward
loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/modules/loss.py", line 1293, in forward
return F.cross_entropy(
^^^^^^^^^^^^^^^^
File "/root/miniforge3/envs/esmIne3/lib/python3.12/site-packages/torch/nn/functional.py", line 3479, in cross_entropy
return torch._C._nn.cross_entropy_loss(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Expected input batch_size (1080) to match target batch_size (2).

Synthyra org

Hi @fish ,

Thank you so much for bringing this to our attention. There was a small bug that was preventing proper functionality; it should be fixed now. Please let me know if you have any other questions.
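One note: because the model is loaded with trust_remote_code=True, the modeling code is cached locally (e.g. under ~/.cache/huggingface/modules/transformers_modules/, as in your traceback), so you may need to re-download to pick up the fix. A minimal sketch; in recent transformers versions force_download should also refresh the cached remote code, and if not, deleting that cache folder works too:

from transformers import AutoModelForSequenceClassification

# force_download=True re-fetches the repository files instead of reusing
# the locally cached copies.
model = AutoModelForSequenceClassification.from_pretrained(
    "Synthyra/ESMplusplus_large",
    trust_remote_code=True,
    num_labels=2,
    force_download=True,
)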

Some checks I ran to confirm the fix:

The model correctly initializes new weights for the classifier head but not for the pretrained backbone:

Some weights of ESMplusplusForSequenceClassification were not initialized from the model checkpoint at Synthyra/ESMplusplus_small and are newly initialized: ['classifier.0.bias', 'classifier.0.weight', 'classifier.2.bias', 'classifier.2.weight', 'classifier.3.bias', 'classifier.3.weight']

The classifier head uses the correct number of labels (in this case 4):

  (classifier): Sequential(
    (0): Linear(in_features=1920, out_features=3840, bias=True)
    (1): GELU(approximate='none')
    (2): LayerNorm((3840,), eps=1e-05, elementwise_affine=True)
    (3): Linear(in_features=3840, out_features=4, bias=True)
  )

The output contains all expected components:

ESMplusplusOutput(loss=tensor(32., grad_fn=<NllLossBackward0>), logits=tensor([[64.,  0.,  0.,  0.],
        [64.,  0.,  0.,  0.]], grad_fn=<AddmmBackward0>), last_hidden_state=tensor([[[-0.0015, -0.0021,  0.0062,  ...,  0.0058, -0.0023, -0.0101],
         [-0.0222, -0.0142,  0.0519,  ...,  0.0017,  0.0189, -0.0007],
         [-0.0461, -0.0351,  0.0073,  ...,  0.0291,  0.0173, -0.0182],
         ...,
         [-0.0224,  0.0053,  0.0110,  ...,  0.0105, -0.0066, -0.0260],
         [-0.0209, -0.0269,  0.0143,  ...,  0.0138,  0.0152, -0.0154],
         [-0.0305, -0.0258,  0.0168,  ...,  0.0107,  0.0223, -0.0148]],

        [[ 0.0021, -0.0066,  0.0084,  ...,  0.0074, -0.0047, -0.0088],
         [-0.0192, -0.0184,  0.0412,  ...,  0.0367,  0.0079,  0.0022],
         [-0.0296, -0.0512,  0.0186,  ...,  0.0244, -0.0138,  0.0069],
         ...,
         [-0.0417, -0.0303,  0.0023,  ...,  0.0016, -0.0020, -0.0031],
         [ 0.0071, -0.0096, -0.0061,  ...,  0.0206, -0.0221, -0.0289],
         [-0.0129, -0.0130,  0.0148,  ...,  0.0025, -0.0064, -0.0149]]],
       grad_fn=<NativeLayerNormBackward0>), hidden_states=())
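These checks are easy to reproduce; for example, the classifier-head check above can be done by printing the head directly (a short sketch, using the small checkpoint as in the warning message; the classifier attribute name is taken from the parameter names in that warning):

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "Synthyra/ESMplusplus_small", trust_remote_code=True, num_labels=4
)
# The final Linear layer's out_features should match num_labels (4 here).
print(model.classifier)
print(model.config.num_labels)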

To run it yourself:

import torch
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("Synthyra/ESMplusplus_large", trust_remote_code=True, num_labels=4)
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')

labels = torch.randint(0, 4, (2,)).long() # random labels for the two sequences
tokenized['labels'] = labels
# if the model is on a GPU, make sure to move the tokenized inputs to the same device
# tokenized = {k: v.to(device) for k, v in tokenized.items()}
output = model(**tokenized)
print(output)
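If you want to drop this back into a Trainer loop like the tutorial uses, a minimal sketch along these lines should work (the toy dataset, output directory, and hyperparameters below are placeholders, not the tutorial's values):

import torch
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained(
    "Synthyra/ESMplusplus_large", trust_remote_code=True, num_labels=2
)
tokenizer = model.tokenizer

# Toy data: replace with your own sequences and labels.
sequences = ["MPRTEIN", "MSEQWENCE", "MKTAYIAKQR", "MSILVTRPSP"]
labels = [0, 1, 0, 1]

class SequenceDataset(Dataset):
    def __init__(self, sequences, labels, tokenizer):
        # Pad to the longest sequence so the default collator can stack the tensors.
        self.encodings = tokenizer(sequences, padding=True, return_tensors="pt")
        self.labels = torch.tensor(labels).long()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {k: v[idx] for k, v in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

train_dataset = SequenceDataset(sequences, labels, tokenizer)

args = TrainingArguments(
    output_dir="esmpp_cls",          # placeholder output directory
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=1,
    report_to="none",                # keep the sketch self-contained
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()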

Thank you so much for your prompt reply! It works well for me now.
Thank you very much for your work; it will greatly accelerate adoption in the scientific community. I will cite your work in my article. Thank you again!

fish changed discussion status to closed
