Configuring Command-R for long context tasks

#32
by beam-me-up-scotty - opened

Apologies for the duplicate post, but the previous related discussion was unclear to me.

saurabhdash mentions:

"This implementation is based on the Llama implementation which materializes this huge buffer which would not be feasible for 128k context. The model does support 128k context with a better implementation."

and then gives the following line of Python:

causal_mask = torch.full( (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool )
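For a sense of scale, here is a rough sketch of how much memory that buffer needs. It assumes one byte per element, which is how torch.bool tensors are stored. The helper name causal_mask_bytes and the example lengths are only illustrative; they are not taken from the model code.

```python
def causal_mask_bytes(max_position_embeddings: int, bytes_per_element: int = 1) -> int:
    """Size in bytes of a dense (N, N) mask; torch.bool uses one byte per element."""
    return max_position_embeddings * max_position_embeddings * bytes_per_element


# Illustrative context lengths: 8k, 32k and 128k positions.
for n in (8192, 32768, 131072):
    gib = causal_mask_bytes(n) / 2**30
    print(f"{n:>7} positions -> {gib:.2f} GiB")
```

The size grows with the square of max_position_embeddings, so raising the value in config.json makes this one buffer 16x larger for every 4x increase in context length.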

What exact steps do we need to follow to implement this?

I've tried editing max_position_embeddings directly in config.json, but I can only run a 25k-token prompt with max_position_embeddings=32768 and 8-bit quantization on a machine with 2x A100 (approx. 160GB VRAM).

Can someone indicate how this default implementation needs to change to use the better implementation mentioned above:

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_id = "CohereForAI/c4ai-command-r-v01"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map="auto", quantization_config=bnb_config)

Hi! Apart from the materialized attention mask, there is another problem: the logits are upcast to fp32. With a sequence length of 128k, the logits alone would take up 128k * 256k (the vocabulary size) * 4 bytes ≈ 131GB. If the goal is generation, one could get rid of this and compute the log-softmax over the last token's logits only.
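The arithmetic above, and the saving from keeping only the last position, can be sketched in plain Python. This is a toy illustration, not the modeling code: logits_bytes and log_softmax are hypothetical helpers, and the three-element list stands in for a real 256k-wide logits row.

```python
import math


def logits_bytes(num_positions: int, vocab_size: int, bytes_per_element: int = 4) -> int:
    """Bytes needed for a (num_positions, vocab_size) tensor; fp32 is 4 bytes."""
    return num_positions * vocab_size * bytes_per_element


def log_softmax(row):
    """Numerically stable log-softmax over a single row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


SEQ_LEN = 128_000      # "128k" from the post, read as 128,000
VOCAB_SIZE = 256_000   # "256k" from the post, read as 256,000

full = logits_bytes(SEQ_LEN, VOCAB_SIZE)   # logits for every position
last_only = logits_bytes(1, VOCAB_SIZE)    # logits for the final position only
print(f"all positions: {full / 1e9:.1f} GB")
print(f"last position: {last_only / 1e6:.3f} MB")

# For next-token generation only the final row is needed.
toy_last_row = [2.0, 0.5, -1.0]
print(log_softmax(toy_last_row))
```

In a real model the slice would be applied to the hidden states before the output projection, so the full (sequence length, vocabulary) tensor is never created in the first place.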

Thanks for your answer @saurabhdash ! In terms of implementation:

  • Would the implementation of causal_mask at line 614 of modeling_cohere.py in forward() need to change to your above implementation?
  • Where would you change the implementation of the logits? Any tips about how to do so?
  • What's a reasonable VRAM usage to expect for a 128k task with these optimisations? Am I over-optimistic to think that we can fit a context of that size on 2x A100s?

Apologies if these are silly questions; I'm still a little new to all this.
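On the VRAM question, one term that can be estimated up front is the key/value cache, which grows linearly with context length. The sketch below uses the standard formula (2 tensors per layer, one entry per KV head and head dimension per token). The config values are assumptions that should be checked against the model's config.json, and kv_cache_bytes is a hypothetical helper. Weights and activations come on top of this figure.

```python
def kv_cache_bytes(seq_len: int, num_layers: int, num_kv_heads: int,
                   head_dim: int, bytes_per_element: int = 2) -> int:
    """Bytes for keys and values across all layers, for a single sequence."""
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_element
    return seq_len * per_token


# ASSUMED values, verify against config.json before relying on them.
NUM_LAYERS = 40      # num_hidden_layers (assumption)
NUM_KV_HEADS = 64    # num_key_value_heads (assumption)
HEAD_DIM = 128       # hidden_size / num_attention_heads (assumption)

for seq_len in (25_000, 32_768, 131_072):
    gib = kv_cache_bytes(seq_len, NUM_LAYERS, NUM_KV_HEADS, HEAD_DIM) / 2**30
    print(f"{seq_len:>7} tokens -> {gib:.1f} GiB of fp16 KV cache")
```

If those assumed values hold, the fp16 cache for a full 128k context is already in the same range as the combined memory of 2x A100 80GB, before counting weights, so a paged or quantized cache would likely be needed.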

Cohere For AI org

I'd recommend waiting for, or using, the vLLM implementation. That should help you scale the context to the maximum.

alexrs changed discussion status to closed
