RuntimeError: FlashAttention only support fp16 and bf16 data type with Flash attention2
#33
by
liougehooa
- opened
When I train with QLoRA + Flash Attention, I get the error below. If I train with plain LoRA + Flash Attention, the error does not occur.
Here is the code snippet for the QLoRA setup:
base_model_id = "microsoft/Phi-3.5-vision-instruct"

# QLoRA quantization settings: load the base weights in 4-bit NF4 with double
# quantization, and use bfloat16 as the compute dtype for the quantized layers.
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load the quantized base model. The model code (modeling_phi3_v.py) is fetched
# from the checkpoint repository, hence trust_remote_code=True.
model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=dtype,  # torch.bfloat16
    trust_remote_code=True,
    quantization_config=bnb_config,
    # `use_flash_attention_2=True` is deprecated in transformers; the supported
    # way to request FlashAttention-2 is the `attn_implementation` argument.
    attn_implementation="flash_attention_2",
).to(device)

# The processor is loaded from the same checkpoint as the model.
processor = transformers.AutoProcessor.from_pretrained(base_model_id, trust_remote_code=True)
training_args = transformers.TrainingArguments(
    # --- Run identity, outputs and experiment tracking ---
    run_name="phi-3-5-vision-finetuning",  # Run name used for experiment tracking
    output_dir="outputs",  # Where model outputs and checkpoints are written
    report_to="wandb",  # Track the run in Weights & Biases (login required)
    push_to_hub=False,  # Do not upload the model to the Hugging Face Hub
    logging_steps=10,  # Emit a log entry every 10 steps
    # --- Epochs, batching and memory ---
    num_train_epochs=40,  # Total number of training epochs
    per_device_train_batch_size=batch_size,  # Per-device batch size for training
    per_device_eval_batch_size=batch_size,  # Per-device batch size for evaluation
    gradient_accumulation_steps=2,  # Accumulate gradients over 2 steps per update
    gradient_checkpointing=True,  # Gradient checkpointing to save memory
    # --- Optimizer and learning-rate schedule ---
    optim="adamw_torch",  # AdamW, PyTorch implementation
    learning_rate=5e-5,  # Learning rate for the optimizer
    lr_scheduler_type="linear",  # Linear learning-rate schedule
    warmup_ratio=0.1,  # Warmup ratio for the learning-rate schedule
    weight_decay=0.001,  # Regularization against overfitting
    max_grad_norm=1,  # Gradient-norm clipping threshold
    # --- Mixed precision ---
    bf16=True,  # Mixed-precision training with bf16
    # fp16=True,  # fp16 alternative (use it if Ampere architecture is unavailable)
    # --- Evaluation, checkpointing and model selection ---
    do_eval=True,  # Run evaluation during training
    evaluation_strategy="steps",  # Evaluate on a step interval
    eval_steps=10,  # Steps between evaluations
    save_strategy="steps",  # Save checkpoints on a step interval
    save_steps=10,  # Steps between checkpoint saves
    save_total_limit=2,  # Keep at most 2 saved checkpoints
    load_best_model_at_end=True,  # Reload the best checkpoint when training ends
    metric_for_best_model="rouge",  # Metric used to pick the best model
    greater_is_better=True,  # A higher metric value is better
)
class CustomTrainer(transformers.Trainer):
    """Trainer subclass that overrides the train and eval dataloader hooks.

    The method bodies are elided (``...``) in this snippet. The traceback in
    this report shows that the real class also defines ``compute_loss``.
    """

    def get_train_dataloader(self):
        """Return the training dataloader (implementation elided in this snippet)."""
        ...

    def get_eval_dataloader(self, eval_dataset=None):
        """Return the evaluation dataloader (implementation elided in this snippet)."""
        ...
def compute_metrics(eval_pred):
    """Compute the ROUGE-1 score for one evaluation pass.

    ``eval_pred`` is unpacked into ``(logits, labels)``. Predictions are the
    arg-max over the last axis of the logits. Label entries equal to ``-100``
    are replaced with the tokenizer's pad token id, then both labels and
    predictions are decoded with special tokens skipped and scored with
    ``rouge``.

    Returns:
        A dict with a single key, ``"rouge"``, holding the ``rouge1`` score.
    """
    raw_logits, raw_labels = eval_pred
    predicted_ids = raw_logits.argmax(-1)

    # -100 is presumably the ignore-index sentinel and not a real token id;
    # swap it for the pad id so the labels can be decoded.
    keep_mask = raw_labels != -100
    label_ids = np.where(keep_mask, raw_labels, processor.tokenizer.pad_token_id)

    reference_texts = processor.batch_decode(label_ids, skip_special_tokens=True)
    predicted_texts = processor.batch_decode(predicted_ids, skip_special_tokens=True)

    scores = rouge.compute(predictions=predicted_texts, references=reference_texts)
    return {"rouge": scores["rouge1"]}
# Ensure the model is in training mode
peft_model.train()
# Build the custom trainer from the PEFT model, the training arguments, the
# data collator, the train/validation datasets and the ROUGE metric function.
trainer = CustomTrainer(
model=peft_model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=val_dataset,
compute_metrics=compute_metrics,
)
# Disable `use_cache` on the model config before training starts.
# NOTE(review): presumably because gradient checkpointing is enabled in
# training_args — confirm.
peft_model.config.use_cache = False
# Start training; this is the call that raises the RuntimeError reported below.
trainer.train()
The error:
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
52 q,
53 k,
54 v,
55 None,
56 alibi_slopes,
57 dropout_p,
58 softmax_scale,
59 causal,
60 window_size[0],
61 window_size[1],
62 return_softmax,
63 None,
64 )
65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
RuntimeError: FlashAttention only support fp16 and bf16 data type
There's a similar problem reported here:
https://huggingface.co/microsoft/Phi-3-small-8k-instruct/discussions/11
As a test, I added the following dtype conversion to CLIPAttentionFA2 (modeling_phi3_v.py):
query_states = self.q_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
value_states = self.v_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
## Added: cast q/k/v to the target dtype before they are passed to flash_attn_func. Set target_dtype manually to the dtype you want, e.g. torch.bfloat16.
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
With the change above, training works.
Library versions:
Flash Attention 2 (flash-attn): 2.5.8
transformers: 4.45.2
peft: 0.11.1
bitsandbytes: 0.44.1
The full traceback:
RuntimeError Traceback (most recent call last)
----> 1 trainer.train()
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/trainer.py:2123, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2121 hf_hub_utils.enable_progress_bars()
2122 else:
-> 2123 return inner_training_loop(
2124 args=args,
2125 resume_from_checkpoint=resume_from_checkpoint,
2126 trial=trial,
2127 ignore_keys_for_eval=ignore_keys_for_eval,
2128 )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/trainer.py:2481, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2475 context = (
2476 functools.partial(self.accelerator.no_sync, model=model)
2477 if i == len(batch_samples) - 1
2478 else contextlib.nullcontext
2479 )
2480 with context():
-> 2481 tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
2483 if (
2484 args.logging_nan_inf_filter
2485 and not is_torch_xla_available()
2486 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
2487 ):
2488 # if loss is nan or inf simply add the average of previous logged losses
2489 tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/trainer.py:3579, in Trainer.training_step(self, model, inputs, num_items_in_batch)
3576 return loss_mb.reduce_mean().detach().to(self.args.device)
3578 with self.compute_loss_context_manager():
-> 3579 loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
3581 del inputs
3582 if (
3583 self.args.torch_empty_cache_steps is not None
3584 and self.state.global_step % self.args.torch_empty_cache_steps == 0
3585 ):
Cell In[38], line 24, in CustomTrainer.compute_loss(self, model, inputs, num_items_in_batch, return_outputs)
23 def compute_loss(self, model, inputs, num_items_in_batch=0, return_outputs=False):
---> 24 outputs = model(**inputs)
25 loss = outputs.loss if isinstance(outputs, dict) else outputs[0]
26 return (loss, outputs) if return_outputs else loss
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/utils/operations.py:823, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
822 def forward(*args, **kwargs):
--> 823 return model_forward(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/utils/operations.py:811, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
810 def __call__(self, *args, **kwargs):
--> 811 return convert_to_fp32(self.model_forward(*args, **kwargs))
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
13 @functools.wraps(func)
14 def decorate_autocast(*args, **kwargs):
15 with autocast_instance:
---> 16 return func(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/peft/peft_model.py:1577, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
1575 with self._enable_peft_forward_hooks(**kwargs):
1576 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1577 return self.base_model(
1578 input_ids=input_ids,
1579 attention_mask=attention_mask,
1580 inputs_embeds=inputs_embeds,
1581 labels=labels,
1582 output_attentions=output_attentions,
1583 output_hidden_states=output_hidden_states,
1584 return_dict=return_dict,
1585 **kwargs,
1586 )
1588 batch_size = _get_batch_size(input_ids, inputs_embeds)
1589 if attention_mask is not None:
1590 # concat prompt attention mask
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/peft/tuners/tuners_utils.py:188, in BaseTuner.forward(self, *args, **kwargs)
187 def forward(self, *args: Any, **kwargs: Any):
--> 188 return self.model.forward(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3.5-vision-instruct/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py:1603, in Phi3VForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, image_sizes, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1600 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1602 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1603 outputs = self.model(
1604 input_ids=input_ids,
1605 attention_mask=attention_mask,
1606 position_ids=position_ids,
1607 past_key_values=past_key_values,
1608 inputs_embeds=inputs_embeds,
1609 pixel_values=pixel_values,
1610 image_sizes=image_sizes,
1611 use_cache=use_cache,
1612 output_attentions=output_attentions,
1613 output_hidden_states=output_hidden_states,
1614 return_dict=return_dict,
1615 )
1617 hidden_states = outputs[0]
1618 logits = self.lm_head(hidden_states)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3.5-vision-instruct/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py:1431, in Phi3VModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, image_sizes, use_cache, output_attentions, output_hidden_states, return_dict)
1429 if pixel_values is not None and image_sizes is not None:
1430 assert self.vision_embed_tokens is not None, "Vision embedding layer is not defined"
-> 1431 inputs_embeds = self.vision_embed_tokens(input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
1432 else:
1433 inputs_embeds = self.embed_tokens(input_ids)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3.5-vision-instruct/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py:237, in Phi3ImageEmbedding.forward(self, input_ids, pixel_values, image_sizes)
235 num_images, num_crops, c, h, w = pixel_values.shape
236 assert c == 3 and h == w == 336
--> 237 img_features = self.get_img_features(pixel_values.flatten(0, 1)).reshape(
238 num_images, num_crops, -1, self.image_dim_out
239 )
240 image_features_proj = self.hd_feature_transform(img_features, image_sizes)
241 hidden_states = hidden_states.index_put(
242 positions, image_features_proj, accumulate=False
243 )
File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3.5-vision-instruct/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py:212, in Phi3ImageEmbedding.get_img_features(self, img_embeds)
209 LAYER_IDX = self.layer_idx
210 TYPE_FEATURE = self.type_feature
--> 212 img_processor_output = self.img_processor(img_embeds, output_hidden_states=True)
213 img_feature = img_processor_output.hidden_states[LAYER_IDX]
215 if TYPE_FEATURE == "patch":
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py:1171, in CLIPVisionModel.forward(self, pixel_values, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict)
1147 r"""
1148 Returns:
1149
(...)
1167 >>> pooled_output = outputs.pooler_output # pooled CLS states
1168 ```"""
1169 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 1171 return self.vision_model(
1172 pixel_values=pixel_values,
1173 output_attentions=output_attentions,
1174 output_hidden_states=output_hidden_states,
1175 return_dict=return_dict,
1176 interpolate_pos_encoding=interpolate_pos_encoding,
1177 )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py:1097, in CLIPVisionTransformer.forward(self, pixel_values, output_attentions, output_hidden_states, return_dict, interpolate_pos_encoding)
1094 hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
1095 hidden_states = self.pre_layrnorm(hidden_states)
-> 1097 encoder_outputs = self.encoder(
1098 inputs_embeds=hidden_states,
1099 output_attentions=output_attentions,
1100 output_hidden_states=output_hidden_states,
1101 return_dict=return_dict,
1102 )
1104 last_hidden_state = encoder_outputs[0]
1105 pooled_output = last_hidden_state[:, 0, :]
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py:869, in CLIPEncoder.forward(self, inputs_embeds, attention_mask, causal_attention_mask, output_attentions, output_hidden_states, return_dict)
867 encoder_states = encoder_states + (hidden_states,)
868 if self.gradient_checkpointing and self.training:
--> 869 layer_outputs = self._gradient_checkpointing_func(
870 encoder_layer.__call__,
871 hidden_states,
872 attention_mask,
873 causal_attention_mask,
874 output_attentions,
875 )
876 else:
877 layer_outputs = encoder_layer(
878 hidden_states,
879 attention_mask,
880 causal_attention_mask,
881 output_attentions=output_attentions,
882 )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/_compile.py:24, in _disable_dynamo.<locals>.inner(*args, **kwargs)
20 @functools.wraps(fn)
21 def inner(*args, **kwargs):
22 import torch._dynamo
---> 24 return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:451, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
449 prior = set_eval_frame(callback)
450 try:
--> 451 return fn(*args, **kwargs)
452 finally:
453 set_eval_frame(prior)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/_dynamo/external_utils.py:36, in wrap_inline.<locals>.inner(*args, **kwargs)
34 @functools.wraps(fn)
35 def inner(*args, **kwargs):
---> 36 return fn(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/utils/checkpoint.py:487, in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
482 if context_fn is not noop_context_fn or debug is not False:
483 raise ValueError(
484 "Passing `context_fn` or `debug` is only supported when "
485 "use_reentrant=False."
486 )
--> 487 return CheckpointFunction.apply(function, preserve, *args)
488 else:
489 gen = _checkpoint_without_reentrant_generator(
490 function, preserve, context_fn, determinism_check, debug, *args, **kwargs
491 )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
595 if not torch._C._are_functorch_transforms_active():
596 # See NOTE: [functorch vjp and autograd interaction]
597 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598 return super().apply(*args, **kwargs) # type: ignore[misc]
600 if not is_setup_ctx_defined:
601 raise RuntimeError(
602 "In order to use an autograd.Function with functorch transforms "
603 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
604 "staticmethod. For more details, please see "
605 " https://pytorch.org/docs/master/notes/extending.func.html"
606 )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/utils/checkpoint.py:262, in CheckpointFunction.forward(ctx, run_function, preserve_rng_state, *args)
259 ctx.save_for_backward(*tensor_inputs)
261 with torch.no_grad():
--> 262 outputs = run_function(*args)
263 return outputs
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py:608, in CLIPEncoderLayer.forward(self, hidden_states, attention_mask, causal_attention_mask, output_attentions)
605 residual = hidden_states
607 hidden_states = self.layer_norm1(hidden_states)
--> 608 hidden_states, attn_weights = self.self_attn(
609 hidden_states=hidden_states,
610 attention_mask=attention_mask,
611 causal_attention_mask=causal_attention_mask,
612 output_attentions=output_attentions,
613 )
614 hidden_states = residual + hidden_states
616 residual = hidden_states
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3.5-vision-instruct/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py:105, in CLIPAttentionFA2.forward(self, hidden_states, attention_mask, causal_attention_mask, output_attentions)
102 key_states = self.k_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
103 value_states = self.v_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
--> 105 attn_output = flash_attn_func(
106 query_states,
107 key_states,
108 value_states,
109 dropout_p=self.dropout if self.training else 0.0,
110 softmax_scale=self.scale,
111 causal=False,
112 ).reshape(bsz, tgt_len, embed_dim)
114 attn_output = self.out_proj(attn_output)
115 return attn_output, None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py:831, in flash_attn_func(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs)
771 def flash_attn_func(
772 q,
773 k,
(...)
781 return_attn_probs=False,
782 ):
783 """dropout_p should be set to 0.0 during evaluation
784 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
785 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
(...)
829 pattern (negative means that location was dropped, nonnegative means it was kept).
830 """
--> 831 return FlashAttnFunc.apply(
832 q,
833 k,
834 v,
835 dropout_p,
836 softmax_scale,
837 causal,
838 window_size,
839 alibi_slopes,
840 deterministic,
841 return_attn_probs,
842 )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
595 if not torch._C._are_functorch_transforms_active():
596 # See NOTE: [functorch vjp and autograd interaction]
597 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598 return super().apply(*args, **kwargs) # type: ignore[misc]
600 if not is_setup_ctx_defined:
601 raise RuntimeError(
602 "In order to use an autograd.Function with functorch transforms "
603 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
604 "staticmethod. For more details, please see "
605 " https://pytorch.org/docs/master/notes/extending.func.html"
606 )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py:511, in FlashAttnFunc.forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax)
509 if softmax_scale is None:
510 softmax_scale = q.shape[-1] ** (-0.5)
--> 511 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
512 q,
513 k,
514 v,
515 dropout_p,
516 softmax_scale,
517 causal=causal,
518 window_size=window_size,
519 alibi_slopes=alibi_slopes,
520 return_softmax=return_softmax and dropout_p > 0,
521 )
522 ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
523 ctx.dropout_p = dropout_p
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
52 q,
53 k,
54 v,
55 None,
56 alibi_slopes,
57 dropout_p,
58 softmax_scale,
59 causal,
60 window_size[0],
61 window_size[1],
62 return_softmax,
63 None,
64 )
65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
RuntimeError: FlashAttention only support fp16 and bf16 data type