StopIteration: Caught StopIteration in replica 0 on device 0.

#57
by seregadgl - opened

Hi. I run this code in kaggle and get an error. Tell me how to fix this error. Thank you!

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from datasets import load_dataset
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer("jinaai/jina-embeddings-v3", device=device, trust_remote_code=True, model_kwargs={'default_task': 'text-matching'})

train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train")
"""
Dataset({
features: ['sentence1', 'sentence2', 'label'],
num_rows: 942069
})
"""

trainer = SentenceTransformerTrainer(
model=model,
train_dataset=train_dataset,

)

trainer.train()

StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
output = module(*input, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py", line 688, in forward
input = module(input, **module_kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/jinaai/jina-embeddings-v3/fa78e35d523dcda8d3b5212c7487cf70a4b277da/custom_st.py", line 143, in forward
output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
File "/root/.cache/huggingface/modules/transformers_modules/jinaai/xlm-roberta-flash-implementation/12700ba4972d9e900313a85ae855f5a76fb9500e/modeling_lora.py", line 357, in forward
return self.roberta(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/jinaai/xlm-roberta-flash-implementation/12700ba4972d9e900313a85ae855f5a76fb9500e/modeling_xlm_roberta.py", line 684, in forward
hidden_states = self.embeddings(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/jinaai/xlm-roberta-flash-implementation/12700ba4972d9e900313a85ae855f5a76fb9500e/embedding.py", line 55, in forward
embedding_dtype = next(self.word_embeddings.parameters()).dtype
StopIteration

Hi @seregadgl , the code snippet works fine on my machine. Can you provide more details such as versions of torch, sentence-transformers, flash-attn so that I can reproduce the error?

Hi @seregadgl , the code snippet works fine on my machine. Can you provide more details such as versions of torch, sentence-transformers, flash-attn so that I can reproduce the error?

I run this code in kaggle, gpu t4 or p100

torch Version: 2.4.0
sentence-transformers Version: 3.2.1
flash-attn Version: 2.6.3
triton Version: 3.1.0


StopIteration Traceback (most recent call last)
Cell In[5], line 29
14 args = SentenceTransformerTrainingArguments(
15 output_dir='/kaggle/working/',
16 fp16=True, # Set to False if you get an error that your GPU can't run on FP16
17 bf16=False, # Set to True if you have a GPU that supports BF16
18 )
21 trainer = SentenceTransformerTrainer(
22 model=model,
23 train_dataset=train_dataset,
(...)
26
27 )
---> 29 trainer.train()

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2052, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2050 hf_hub_utils.enable_progress_bars()
2051 else:
-> 2052 return inner_training_loop(
2053 args=args,
2054 resume_from_checkpoint=resume_from_checkpoint,
2055 trial=trial,
2056 ignore_keys_for_eval=ignore_keys_for_eval,
2057 )

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2388, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2385 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
2387 with self.accelerator.accumulate(model):
-> 2388 tr_loss_step = self.training_step(model, inputs)
2390 if (
2391 args.logging_nan_inf_filter
2392 and not is_torch_xla_available()
2393 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
2394 ):
2395 # if loss is nan or inf simply add the average of previous logged losses
2396 tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:3485, in Trainer.training_step(self, model, inputs)
3482 return loss_mb.reduce_mean().detach().to(self.args.device)
3484 with self.compute_loss_context_manager():
-> 3485 loss = self.compute_loss(model, inputs)
3487 del inputs
3488 if (
3489 self.args.torch_empty_cache_steps is not None
3490 and self.state.global_step % self.args.torch_empty_cache_steps == 0
3491 ):

File /opt/conda/lib/python3.10/site-packages/sentence_transformers/trainer.py:348, in SentenceTransformerTrainer.compute_loss(self, model, inputs, return_outputs)
341 if (
342 model == self.model_wrapped
343 and model != self.model # Only if the model is wrapped
344 and hasattr(loss_fn, "model") # Only if the loss stores the model
345 and loss_fn.model != model # Only if the wrapped model is not already stored
346 ):
347 loss_fn = self.override_model_in_loss(loss_fn, model)
--> 348 loss = loss_fn(features, labels)
349 if return_outputs:
350 # During prediction/evaluation, compute_loss will be called with return_outputs=True.
351 # However, Sentence Transformer losses do not return outputs, so we return an empty dictionary.
352 # This does not result in any problems, as the SentenceTransformerTrainingArguments sets
353 # prediction_loss_only=True which means that the output is not used.
354 return loss, {}

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /opt/conda/lib/python3.10/site-packages/sentence_transformers/losses/CoSENTLoss.py:81, in CoSENTLoss.forward(self, sentence_features, labels)
80 def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
---> 81 embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
83 scores = self.similarity_fct(embeddings[0], embeddings[1])
84 scores = scores * self.scale

File /opt/conda/lib/python3.10/site-packages/sentence_transformers/losses/CoSENTLoss.py:81, in (.0)
80 def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
---> 81 embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
83 scores = self.similarity_fct(embeddings[0], embeddings[1])
84 scores = scores * self.scale

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:186, in DataParallel.forward(self, *inputs, **kwargs)
184 return self.module(*inputs[0], **module_kwargs[0])
185 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 186 outputs = self.parallel_apply(replicas, inputs, module_kwargs)
187 return self.gather(outputs, self.output_device)

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:201, in DataParallel.parallel_apply(self, replicas, inputs, kwargs)
200 def parallel_apply(self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) -> List[Any]:
--> 201 return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:108, in parallel_apply(modules, inputs, kwargs_tup, devices)
106 output = results[i]
107 if isinstance(output, ExceptionWrapper):
--> 108 output.reraise()
109 outputs.append(output)
110 return outputs

File /opt/conda/lib/python3.10/site-packages/torch/_utils.py:706, in ExceptionWrapper.reraise(self)
702 except TypeError:
703 # If the exception takes multiple arguments, don't try to
704 # instantiate since we don't know how to
705 raise RuntimeError(msg) from None
--> 706 raise exception

StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
output = module(*input, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/sentence_transformers/SentenceTransformer.py", line 688, in forward
input = module(input, **module_kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/jinaai/jina-embeddings-v3/fa78e35d523dcda8d3b5212c7487cf70a4b277da/custom_st.py", line 143, in forward
output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
File "/root/.cache/huggingface/modules/transformers_modules/jinaai/xlm-roberta-flash-implementation/12700ba4972d9e900313a85ae855f5a76fb9500e/modeling_lora.py", line 357, in forward
return self.roberta(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/jinaai/xlm-roberta-flash-implementation/12700ba4972d9e900313a85ae855f5a76fb9500e/modeling_xlm_roberta.py", line 684, in forward
hidden_states = self.embeddings(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/jinaai/xlm-roberta-flash-implementation/12700ba4972d9e900313a85ae855f5a76fb9500e/embedding.py", line 55, in forward
embedding_dtype = next(self.word_embeddings.parameters()).dtype
StopIteration

Jina AI org

Thanks for more detailed info. Can you also tell me which cuda version you use?

Thanks for more detailed info. Can you also tell me which cuda version you use?

Yes, of course. Version cuda 12.3.
Maybe there is a possibility to test your simplest working code, maybe I am doing something wrong, although it is very strange.

Sign up or log in to comment