KeyError in triton implementation
I'm loading in the triton implementation of the model using a custom device map and trying to generate an output as follows (to be clear, I have no issues with the torch implementation):
torch_dtype = torch.bfloat16
config = AutoConfig.from_pretrained(
'mosaicml/mpt-7b',
trust_remote_code=True
)
config.attn_config['attn_impl'] = 'triton'
config.update({"max_seq_len": max_len})
config.update({"torch_dtype": torch_dtype})
with open('MPT_device_map.pkl', 'rb') as f:
dm = pickle.load(f)
model = AutoModelForCausalLM.from_pretrained(
'mosaicml/mpt-7b-instruct',
torch_dtype=torch_dtype,
trust_remote_code=True,
device_map=dm,
config=config,
local_files_only=True
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(text, return_tensors="pt", padding=True).input_ids.to(device)
streamer = TextStreamer(tokenizer)
with torch.inference_mode():
generate_ids = model.generate(inputs, **params, streamer=streamer)
generate_ids = generate_ids[:,inputs[0].shape[-1]:]
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]
And I'm getting the following error:
```
/usr/bin/ld: skipping incompatible /usr/lib/libcuda.so when searching for -lcuda
KeyError Traceback (most recent call last)
File :21, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)
KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-d962222789c30252d492a16cca3bf467-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False)))
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
Cell In[41], line 8
4 with torch.inference_mode():
5 # res = nlp(text, max_new_tokens=mnt, min_new_tokens=1, return_full_text=False)
6 # inputs = input_map(inputs)
7 st = time.time()
----> 8 generate_ids = model.generate(inputs, **params, streamer=streamer)
9 # generate_ids = model.module.generate(**inputs, **params, streamer=streamer)
10 tt = time.time() - st
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.call..decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:1565, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
1557 input_ids, model_kwargs = self._expand_inputs_for_generation(
1558 input_ids=input_ids,
1559 expand_size=generation_config.num_return_sequences,
1560 is_encoder_decoder=self.config.is_encoder_decoder,
1561 **model_kwargs,
1562 )
1564 # 13. run sample
-> 1565 return self.sample(
1566 input_ids,
1567 logits_processor=logits_processor,
1568 logits_warper=logits_warper,
1569 stopping_criteria=stopping_criteria,
1570 pad_token_id=generation_config.pad_token_id,
1571 eos_token_id=generation_config.eos_token_id,
1572 output_scores=generation_config.output_scores,
1573 return_dict_in_generate=generation_config.return_dict_in_generate,
1574 synced_gpus=synced_gpus,
1575 streamer=streamer,
1576 **model_kwargs,
1577 )
1579 elif is_beam_gen_mode:
1580 if generation_config.num_return_sequences > generation_config.num_beams:
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:2612, in GenerationMixin.sample(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
2609 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2611 # forward pass to get next token
-> 2612 outputs = self(
2613 **model_inputs,
2614 return_dict=True,
2615 output_attentions=output_attentions,
2616 output_hidden_states=output_hidden_states,
2617 )
2619 if synced_gpus and this_peer_finished:
2620 continue # don't waste resources running the code we don't need
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module..new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/modeling_mpt.py:237, in MPTForCausalLM.forward(self, input_ids, past_key_values, attention_mask, prefix_mask, sequence_id, labels, return_dict, output_attentions, output_hidden_states, use_cache)
235 return_dict = return_dict if return_dict is not None else self.config.return_dict
236 use_cache = use_cache if use_cache is not None else self.config.use_cache
--> 237 outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
238 logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
239 if self.logit_scale is not None:
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/modeling_mpt.py:183, in MPTModel.forward(self, input_ids, past_key_values, attention_mask, prefix_mask, sequence_id, return_dict, output_attentions, output_hidden_states, use_cache)
181 all_hidden_states = all_hidden_states + (x,)
182 past_key_value = past_key_values[b_idx] if past_key_values is not None else None
--> 183 (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
184 if past_key_values is not None:
185 past_key_values[b_idx] = past_key_value
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module..new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/blocks.py:36, in MPTBlock.forward(self, x, past_key_value, attn_bias, attention_mask, is_causal)
34 def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35 a = self.norm_1(x)
---> 36 (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37 x = x + self.resid_attn_dropout(b)
38 m = self.norm_2(x)
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module..new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/attention.py:171, in MultiheadAttention.forward(self, x, past_key_value, attn_bias, attention_mask, is_causal, needs_weights)
169 if attn_bias is not None:
170 attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
--> 171 (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
172 return (self.out_proj(context), attn_weights, past_key_value)
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/attention.py:111, in triton_flash_attn_fn(query, key, value, n_heads, softmax_scale, attn_bias, key_padding_mask, is_causal, dropout_p, training, needs_weights, multiquery)
109 value = value.expand(*value.shape[:2], n_heads, value.size(-1))
110 reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
--> 111 attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
112 output = attn_output.view(*attn_output.shape[:2], -1)
113 return (output, None)
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py:810, in FlashAttnFunc.forward(ctx, q, k, v, bias, causal, softmax_scale)
808 # Make sure that the last dimension is contiguous
809 q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
--> 810 o, lse, ctx.softmax_scale = _flash_attn_forward(
811 q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
812 )
813 ctx.save_for_backward(q, k, v, o, lse, bias)
814 ctx.causal = causal
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py:623, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale)
621 num_warps = 4 if d <= 64 else 8
622 grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
--> 623 _fwd_kernel[grid](
624 q, k, v, bias, o,
625 lse, tmp,
626 softmax_scale,
627 q.stride(0), q.stride(2), q.stride(1),
628 k.stride(0), k.stride(2), k.stride(1),
629 v.stride(0), v.stride(2), v.stride(1),
630 *bias_strides,
631 o.stride(0), o.stride(2), o.stride(1),
632 nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
633 seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
634 # Can't use kwargs here because triton autotune expects key to be args, not kwargs
635 # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
636 bias_type, causal, BLOCK_HEADDIM,
637 BLOCK_M=BLOCK, BLOCK_N=BLOCK,
638 num_warps=num_warps,
639 num_stages=1,
640 )
641 return o, lse, softmax_scale
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/triton/runtime/jit.py:106, in KernelInterface.getitem..launcher(*args, **kwargs)
105 def launcher(*args, **kwargs):
--> 106 return self.run(*args, grid=grid, **kwargs)
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/triton/runtime/autotuner.py:200, in Heuristics.run(self, *args, **kwargs)
198 for v, heur in self.values.items():
199 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 200 return self.fn.run(*args, **kwargs)
File :43, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)
RuntimeError: Triton Error [CUDA]: invalid argument
Any ideas what might be causing this? I'm working with:
Python: 3.10
CUDA: 11.7
triton: 2.0.0.dev20221202
flash-attn: 1.0.3.post0
transformers: 4.29.2
torch: 1.13.1+cu117
Yeah, I'm getting this too
Update + additional context for this error. Was using T4 NVIDIA GPUs for above error. Switched to test on V100 GPUs with same packages/installs, and am now getting something different:
```
Briefly explain to me what the Reimann Hypothesis
/usr/bin/ld: skipping incompatible /usr/lib/libcuda.so when searching for -lcuda
is.
â-s�AN’s,
/usr/bin/ld: skipping incompatible /usr/lib/libcuda.so when searching for -lcuda
RuntimeError Traceback (most recent call last)
Cell In[46], line 8
4 with torch.inference_mode():
5 # res = nlp(text, max_new_tokens=mnt, min_new_tokens=1, return_full_text=False)
6 # inputs = input_map(inputs)
7 st = time.time()
----> 8 generate_ids = model.generate(inputs, **params, streamer=streamer)
9 # generate_ids = model.module.generate(**inputs, **params, streamer=streamer)
10 tt = time.time() - st
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.call..decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:1565, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
1557 input_ids, model_kwargs = self._expand_inputs_for_generation(
1558 input_ids=input_ids,
1559 expand_size=generation_config.num_return_sequences,
1560 is_encoder_decoder=self.config.is_encoder_decoder,
1561 **model_kwargs,
1562 )
1564 # 13. run sample
-> 1565 return self.sample(
1566 input_ids,
1567 logits_processor=logits_processor,
1568 logits_warper=logits_warper,
1569 stopping_criteria=stopping_criteria,
1570 pad_token_id=generation_config.pad_token_id,
1571 eos_token_id=generation_config.eos_token_id,
1572 output_scores=generation_config.output_scores,
1573 return_dict_in_generate=generation_config.return_dict_in_generate,
1574 synced_gpus=synced_gpus,
1575 streamer=streamer,
1576 **model_kwargs,
1577 )
1579 elif is_beam_gen_mode:
1580 if generation_config.num_return_sequences > generation_config.num_beams:
File ~/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:2648, in GenerationMixin.sample(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
2646 # sample
2647 probs = nn.functional.softmax(next_token_scores, dim=-1)
-> 2648 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
2650 # finished sentences should have their next token be a padding token
2651 if eos_token_id is not None:
RuntimeError: probability tensor contains either inf
, nan
or element < 0
Maybe this can provide more context? At the beginning you can see the prompt given: "Briefly explain to me what the Reimann Hypothesis is." Spits out a gibberish token, then fails. Any thoughts?
@sam-mosaic
I'm not sure if this is the root cause, but we just added device_map
support recently: https://huggingface.co/mosaicml/mpt-7b-instruct/discussions/41
@abhi-mosaic still seeing the same issue (RuntimeError: Triton Error [CUDA]: invalid argument) with auto device map. relevant packages/installs:
Python 3.10
CUDA 11.7
4x16GB T4 GPUs
einops==0.5.0
torch==1.13.1
transformers==4.29.2
triton-pre-mlir @ git+https://github.com/vchiley/triton.git@48b1cc9ff8b1f506ac32f2124471e2582875c008#subdirectory=python
Also, testing again on 4x16GB V100 GPUs with the same installs, I get another error as noted above: "RuntimeError: probability tensor contains either inf
, nan
or element < 0"
We have only tested triton on A10s and A100s. It may not be an option on either of those GPUs.
@sam-mosaic is bfloat16 precision required for the triton implementation? T4s and V100s don't support bfloat16 precision, but I've tried with regular float16 precision as well and get the same error. so if the triton implementation can't run on regular float16 precision, then the lack of support for bfloat16 precision on those GPUs would explain this issue.