Reducing Memory Usage

Section under construction. Feel free to contribute!

Truncation

Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.
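
To see the effect concretely, here is an illustrative sketch of batch padding (the gpt2 checkpoint and the padding setup are arbitrary choices, not TRL specifics):

from transformers import AutoTokenizer

# Pad a toy batch to its longest sequence.
# gpt2 has no pad token by default, so we reuse the EOS token.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

batch = ["A short sequence.", "A much longer sequence. " * 50]
encoded = tokenizer(batch, padding=True, return_tensors="pt")

# Both rows are padded to the length of the longest sequence,
# so the short one is stored mostly as padding tokens.
print(encoded["input_ids"].shape)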

(Figure: truncation of the prompt and completion)

To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.

In DPO, truncation is applied first to the prompt and to the completion, via the max_prompt_length and max_completion_length parameters. The max_length parameter is then used to truncate the resulting (prompt + completion) sequence.

To set the truncation parameters, use the following code snippet:

from trl import DPOConfig

training_args = DPOConfig(..., max_prompt_length=..., max_length=...)
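
For example, with max_prompt_length=512 and max_length=1024 (illustrative values, not defaults), a 600-token prompt is first truncated to 512 tokens; if the completion is 600 tokens, the combined 1112-token sequence is then truncated to 1024 tokens.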

You can also use the max_completion_length parameter to truncate the completion, though this is less common since the goal is typically to preserve the completion’s full length whenever possible.

from trl import DPOConfig

training_args = DPOConfig(..., max_completion_length=...)
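
For SFT, truncation is instead controlled by a single parameter, max_seq_length, which caps the length of the full tokenized sequence (the same parameter appears in the packing example below):

from trl import SFTConfig

training_args = SFTConfig(..., max_seq_length=...)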

Packing

This technique applies only to SFT.

Truncation has several drawbacks:

  1. Loss of information: Key data at the end of a sequence may be discarded.
  2. Choosing truncation length: Too short loses data; too long undermines efficiency.

Packing, introduced in Raffel et al., 2020, addresses these issues by grouping sequences instead of truncating them: it concatenates the tokenized dataset sequences and splits the result into chunks of the desired length, as sketched below.
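
Here is a minimal sketch of the idea, assuming pre-tokenized sequences; it is not TRL's actual implementation:

def pack(sequences, chunk_size):
    # Concatenate all token sequences into one stream,
    # then split the stream into fixed-size chunks.
    flat = [token for seq in sequences for token in seq]
    return [flat[i : i + chunk_size] for i in range(0, len(flat), chunk_size)]

# Three sequences of different lengths become fixed-size chunks with no padding.
# (A real implementation may drop or pad the final short chunk.)
print(pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], chunk_size=4))
# [[1, 2, 3, 4], [5, 6, 7, 8], [9]]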

Packing eliminates padding, preserves all sequence information, and allows for flexible sequence lengths, making it a more efficient alternative to truncation. To enable packing, use packing=True in the SFTConfig:

from trl import SFTConfig

training_args = SFTConfig(..., packing=True, max_seq_length=512)

Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see #1230.

Disabling model gathering for generation in online methods

When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: #2250.

If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter:

This parameter is available for Online DPO (OnlineDPOConfig), PPO (PPOConfig), and RLOO (RLOOConfig). For example, with Online DPO:

from trl import OnlineDPOConfig

training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)

This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds.
