Fine-tuning SmolVLM with TRL on a consumer GPU

Authored by: Sergio Paniego

In this recipe, we'll demonstrate how to fine-tune a smol 🤏 Vision Language Model (VLM) using the Hugging Face ecosystem, leveraging the powerful Transformer Reinforcement Learning library (TRL). This step-by-step guide will enable you to customize VLMs for your specific tasks, even on consumer GPUs.

🌟 Model & Dataset Overview

In this notebook, we will fine-tune the SmolVLM model using the ChartQA dataset. SmolVLM is a highly performant and memory-efficient model, making it an ideal choice for this task. The ChartQA dataset contains images of various chart types paired with question-answer pairs, offering a valuable resource for enhancing the model's visual question-answering (VQA) capabilities. These skills are crucial for a range of practical applications, including data analysis, business intelligence, and educational tools.

💡 Note: The instruct model we are fine-tuning has already been trained on this dataset, so it is familiar with the data. However, this serves as a valuable educational exercise for understanding fine-tuning techniques. For a complete list of datasets used to train this model, check out this document.

📖 Additional Resources

Expand your knowledge of Vision Language Models and related tools with these resources:

With these resources, you'll be equipped to dive deeper into the world of VLMs and push the boundaries of what they can achieve!

This notebook was tested on an L4 GPU.

Smol VLMs comparison

1. Install Dependencies

Let's start by installing the essential libraries we'll need for fine-tuning! 🚀

!pip install  -U -q transformers trl datasets bitsandbytes peft accelerate
# Tested with transformers==4.46.3, trl==0.12.1, datasets==3.1.0, bitsandbytes==0.45.0, peft==0.13.2, accelerate==1.1.1
!pip install -q flash-attn --no-build-isolation
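
Since `pip install -U` pulls the latest releases, the versions you end up with may differ from the tested ones listed in the comment above. The sketch below is one way to check, using only the standard library; the helper name `installed_versions` is ours and not part of any of these packages.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages installed by the cells above.
PACKAGES = ["transformers", "trl", "datasets", "bitsandbytes", "peft", "accelerate", "flash-attn"]


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name}: {ver or 'not installed'}")
```

If a later cell fails after an upgrade, comparing this output against the tested versions is a reasonable first debugging step.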

Authenticate with your Hugging Face account to save and share your model directly from this notebook 🗝️.

from huggingface_hub import notebook_login

notebook_login()

2. Load Dataset 📁

We'll load the HuggingFaceM4/ChartQA dataset, which provides chart images along with corresponding questions and answers, making it well suited to fine-tuning visual question-answering models.

We'll create a system message to make the VLM act as a chart analysis expert, giving concise answers about chart images.

system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary."""

We'll format the dataset into a chatbot structure, with the system message, image, user query, and answer for each interaction.

💡 For more tips on using this model, check out the Model Card.

def format_data(sample):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": sample["image"],
                },
                {
                    "type": "text",
                    "text": sample["query"],
                },
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": sample["label"][0]}],
        },
    ]
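
To make the resulting layout easier to inspect, here is a small sketch that summarizes the roles and content types of one formatted sample. The helper `summarize_conversation` and the sample values are made up for illustration, and a plain string stands in for the PIL image so the snippet runs without the dataset.

```python
def summarize_conversation(messages):
    """Return (role, [content types]) pairs for a chat-formatted sample."""
    return [(turn["role"], [part["type"] for part in turn["content"]]) for turn in messages]


# Hypothetical sample in the same layout that format_data produces.
# A plain string stands in for the PIL image so the sketch runs without the dataset.
example = [
    {"role": "system", "content": [{"type": "text", "text": "You are a chart expert."}]},
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "<PIL.Image placeholder>"},
            {"type": "text", "text": "Which year has the highest value?"},
        ],
    },
    {"role": "assistant", "content": [{"type": "text", "text": "2019"}]},
]

for role, types in summarize_conversation(example):
    print(role, types)
```

The later cells rely on this ordering: index 1 is the user turn, and its first content item holds the image.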

For educational purposes, we'll load only 10% of each split in the dataset. In a real-world scenario, you would load the entire dataset.

from datasets import load_dataset

dataset_id = "HuggingFaceM4/ChartQA"
train_dataset, eval_dataset, test_dataset = load_dataset(dataset_id, split=["train[:10%]", "val[:10%]", "test[:10%]"])

Let's take a look at the dataset structure. It includes an image, a query, a label (the answer), and a fourth feature that we'll be discarding.

train_dataset

Now, let's format the data using the chatbot structure. This will set up the interactions for the model.

train_dataset = [format_data(sample) for sample in train_dataset]
eval_dataset = [format_data(sample) for sample in eval_dataset]
test_dataset = [format_data(sample) for sample in test_dataset]
train_dataset[200]

3. Load Model and Check Performance! 🤔

Now that we've loaded the dataset, it's time to load HuggingFaceTB/SmolVLM-Instruct, a 2B parameter Vision Language Model (VLM) that offers state-of-the-art (SOTA) performance while being efficient in terms of memory usage.

For a broader comparison of state-of-the-art VLMs, explore the WildVision Arena and the OpenVLM Leaderboard, where you can find the best-performing models across various benchmarks.


import torch
from transformers import Idefics3ForConditionalGeneration, AutoProcessor

model_id = "HuggingFaceTB/SmolVLM-Instruct"

Next, we'll load the model and the processor to prepare for inference.

model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

processor = AutoProcessor.from_pretrained(model_id)

To evaluate the model's performance, we'll use a sample from the dataset. First, let's inspect the internal structure of this sample to understand how the data is organized.

train_dataset[1]

We'll use the sample without the system message to assess the VLM's raw understanding. Here's the input we will use:

train_dataset[1][1:2]

Now, let's take a look at the chart corresponding to the sample. Can you answer the query based on the visual information?

>>> train_dataset[1][1]["content"][0]["image"]

Let's create a function that takes the model, processor, and sample as inputs and generates the model's answer. This will allow us to streamline the inference process and easily evaluate the VLM's performance.

def generate_text_from_sample(model, processor, sample, max_new_tokens=1024, device="cuda"):
    # Prepare the text input by applying the chat template
    text_input = processor.apply_chat_template(
        sample[1:2], add_generation_prompt=True  # Use the sample without the system message
    )

    image_inputs = []
    image = sample[1]["content"][0]["image"]
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_inputs.append([image])

    # Prepare the inputs for the model and move them to the specified device
    model_inputs = processor(
        text=text_input,
        images=image_inputs,
        return_tensors="pt",
    ).to(device)

    # Generate text with the model
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Trim the generated ids to remove the input ids
    trimmed_generated_ids = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)]

    # Decode the output text
    output_text = processor.batch_decode(
        trimmed_generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_text[0]  # Return the first decoded output text

output = generate_text_from_sample(model, processor, train_dataset[1])
output

It seems like the model is referencing the wrong line, causing it to fail. To improve its performance, we can fine-tune the model with more relevant data to ensure it better understands the context and provides more accurate responses.
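
The trimming step inside `generate_text_from_sample` relies on `generate` returning the prompt tokens followed by the newly generated ones. The sketch below isolates that slicing with plain lists; the token ids are made up and `trim_prompt_tokens` is a name chosen for this illustration.

```python
def trim_prompt_tokens(input_ids, generated_ids):
    """Drop the prompt prefix from each generated sequence."""
    return [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)]


# Made-up token ids: each generated sequence is the prompt plus its continuation.
prompt_ids = [[101, 7, 8, 9]]
generated = [[101, 7, 8, 9, 42, 43, 2]]

print(trim_prompt_tokens(prompt_ids, generated))  # [[42, 43, 2]]
```

Without this step, the decoded output would repeat the whole prompt before the answer.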

Remove Model and Clean GPU

Before we proceed with training the model in the next section, let's clear the current variables and clean the GPU to free up resources.

>>> import gc
>>> import time


>>> def clear_memory():
...     # Delete variables if they exist in the current global scope
...     if "inputs" in globals():
...         del globals()["inputs"]
...     if "model" in globals():
...         del globals()["model"]
...     if "processor" in globals():
...         del globals()["processor"]
...     if "trainer" in globals():
...         del globals()["trainer"]
...     if "peft_model" in globals():
...         del globals()["peft_model"]
...     if "bnb_config" in globals():
...         del globals()["bnb_config"]
...     time.sleep(2)

...     # Garbage collection and clearing CUDA memory
...     gc.collect()
...     time.sleep(2)
...     torch.cuda.empty_cache()
...     torch.cuda.synchronize()
...     time.sleep(2)
...     gc.collect()
...     time.sleep(2)

...     print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
...     print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


>>> clear_memory()
GPU allocated memory: 0.01 GB
GPU reserved memory: 0.06 GB
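
The reason `clear_memory` deletes the variables first is that cached GPU memory can only be released once nothing references the underlying objects. The sketch below shows the same idea on the CPU with a stand-in object. `FakeModel` and `drop_globals` are hypothetical names, and in a notebook you would pass `globals()` as the namespace.

```python
import gc
import weakref


class FakeModel:
    """Stand-in for a large object such as a model (hypothetical)."""


def drop_globals(names, namespace):
    """Delete each name from the namespace if present and return the names removed."""
    removed = [name for name in names if name in namespace]
    for name in removed:
        del namespace[name]
    gc.collect()  # also reclaim objects kept alive only by reference cycles
    return removed


namespace = {"model": FakeModel(), "other": 1}
tracker = weakref.ref(namespace["model"])  # lets us observe whether the object is still alive

print(drop_globals(["model", "processor", "trainer"], namespace))  # ['model']
print(tracker() is None)  # True
```

Looping over a list of names keeps the function short as more variables are added, compared with one `if` block per variable.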

4. Fine-Tune the Model using TRL

4.1 Load the Quantized Model for Training ⚙️

Next, we'll load the quantized model using bitsandbytes. If you want to learn more about quantization, check out this blog post or this one.

from transformers import BitsAndBytesConfig

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and processor
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)
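
To get a feel for what 4-bit loading saves, here is a back-of-the-envelope estimate of weight storage at different bit widths. The parameter count of 2.2 billion is an assumption for illustration, and the estimate ignores quantization constants, activations, and any layers that are left unquantized, so real usage will be higher.

```python
GIB = 1024**3


def weight_memory_gib(num_params, bits_per_param):
    """Approximate memory needed to store num_params weights at a given bit width."""
    return num_params * bits_per_param / 8 / GIB


# Assumption: roughly 2.2 billion weights, ignoring quantization constants
# and any layers that bitsandbytes leaves unquantized.
num_params = 2_200_000_000

bf16_gib = weight_memory_gib(num_params, 16)
nf4_gib = weight_memory_gib(num_params, 4)
print(f"bf16: {bf16_gib:.2f} GiB, 4-bit: {nf4_gib:.2f} GiB")
```

Under these assumptions the 4-bit weights take a quarter of the bf16 footprint, which is what leaves room for gradients and optimizer state on a consumer GPU.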

4.2 Set Up QLoRA and SFTConfig 🚀

Next, we'll configure QLoRA for our training setup. QLoRA allows efficient fine-tuning of large models by reducing the memory footprint. Like LoRA, it trains small low-rank adapter matrices while the base weights stay frozen. In addition, QLoRA keeps the frozen base model in 4-bit precision (the NF4 configuration loaded above), while the adapter weights remain in higher precision, which lowers memory usage further.

To reduce memory even more, QLoRA can be combined with a paged or 8-bit optimizer, which offloads or shrinks the optimizer state. The training configuration in this notebook keeps the standard fused AdamW optimizer.

>>> from peft import LoraConfig, get_peft_model

>>> # Configure LoRA
>>> peft_config = LoraConfig(
...     r=8,
...     lora_alpha=8,
...     lora_dropout=0.1,
...     target_modules=["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
...     use_dora=True,
...     init_lora_weights="gaussian",
... )

>>> # Apply PEFT model adaptation
>>> peft_model = get_peft_model(model, peft_config)

>>> # Print trainable parameters
>>> peft_model.print_trainable_parameters()
trainable params: 11,269,248 || all params: 2,257,542,128 || trainable%: 0.4992
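
The small trainable share follows from how LoRA sizes its adapters: for a linear layer, it adds two matrices of shapes (rank x in_features) and (out_features x rank). The sketch below works through that arithmetic for a hypothetical layer size and recomputes the percentage from the numbers printed above. Note that with `use_dora=True` each adapted layer should also carry a small extra magnitude parameter, so the real total will not match the plain LoRA formula exactly.

```python
def lora_param_count(in_features, out_features, rank):
    """Parameters added by one LoRA adapter: A is (rank x in), B is (out x rank)."""
    return rank * in_features + out_features * rank


# Hypothetical layer size, chosen only for illustration.
print(lora_param_count(in_features=2048, out_features=2048, rank=8))  # 32768

# Trainable share, using the numbers printed above.
trainable, total = 11_269_248, 2_257_542_128
print(f"trainable%: {100 * trainable / total:.4f}")  # trainable%: 0.4992
```

For comparison, the full weight matrix of that hypothetical layer would hold 2048 x 2048 = 4,194,304 parameters.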

We will use Supervised Fine-Tuning (SFT) to improve our model's performance on the specific task. To achieve this, we'll define the training arguments with the SFTConfig class from the TRL library. SFT leverages labeled data to help the model generate more accurate responses, adapting it to the task. This approach enhances the model's ability to understand and respond to visual queries more effectively.

from trl import SFTConfig

# Configure training arguments using SFTConfig
training_args = SFTConfig(
    output_dir="smolvlm-instruct-trl-sft-ChartQA",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    save_strategy="steps",
    save_steps=25,
    save_total_limit=1,
    optim="adamw_torch_fused",
    bf16=True,
    push_to_hub=True,
    report_to="tensorboard",
    remove_unused_columns=False,
    gradient_checkpointing=True,
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)
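
A few of these values interact, so it can help to work out what they imply. The sketch below computes the effective batch size per optimizer update on a single GPU and a rough number of updates per epoch. The dataset size is a placeholder, and the trainer's exact step count may differ slightly depending on how it handles the last partial batch.

```python
import math

per_device_train_batch_size = 4
gradient_accumulation_steps = 4
warmup_steps = 50

# Samples contributing to each optimizer update on a single GPU.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 16

# Hypothetical dataset size; substitute len(train_dataset).
num_examples = 2_000
approx_steps_per_epoch = math.ceil(num_examples / effective_batch_size)
print(approx_steps_per_epoch)  # 125
print(f"warmup covers about {warmup_steps / approx_steps_per_epoch:.0%} of the epoch")
```

If warmup takes up a large share of a single short epoch, lowering `warmup_steps` or training for more epochs may be worth trying.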

4.3 Training the Model 🏃

To ensure that the data is correctly structured for the model during training, we need to define a collator function. This function will handle the formatting and batching of our dataset inputs, ensuring the data is properly aligned for training.

👉 For more details, check out the official TRL example scripts.

image_token_id = processor.tokenizer.additional_special_tokens_ids[
    processor.tokenizer.additional_special_tokens.index("<image>")
]


def collate_fn(examples):
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]

    image_inputs = []
    for example in examples:
        image = example[1]["content"][0]["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_inputs.append([image])

    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Mask padding tokens in labels
    labels[labels == image_token_id] = -100  # Mask image token IDs in labels

    batch["labels"] = labels

    return batch
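
The two masking lines in `collate_fn` are what keep padding and image placeholder positions out of the loss. The sketch below reproduces that logic on plain lists so the effect is easy to see. The token ids are made up, and `mask_labels` is a helper written for this illustration.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def mask_labels(input_ids, ignored_token_ids):
    """Copy input_ids into labels, replacing ignored token ids with IGNORE_INDEX."""
    return [
        [IGNORE_INDEX if token in ignored_token_ids else token for token in sequence]
        for sequence in input_ids
    ]


# Made-up ids: 0 plays the role of the pad token and 9 the <image> token.
batch_input_ids = [
    [9, 9, 5, 6, 7],
    [9, 9, 5, 0, 0],
]
print(mask_labels(batch_input_ids, ignored_token_ids={0, 9}))
# [[-100, -100, 5, 6, 7], [-100, -100, 5, -100, -100]]
```

Every other position, including the system and user text, still contributes to the loss in this setup.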

Now, we will define the SFTTrainer, which is a wrapper around the transformers.Trainer class and inherits its attributes and methods. This class simplifies the fine-tuning process by properly initializing the PeftModel when a PeftConfig object is provided. By using SFTTrainer, we can efficiently manage the training workflow and ensure a smooth fine-tuning experience for our Vision Language Model.

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    tokenizer=processor.tokenizer,
)

Time to Train the Model! 🎉

trainer.train()

Let's save the results 💾

trainer.save_model(training_args.output_dir)

5. Testing the Fine-Tuned Model 🔍

Now that our Vision Language Model (VLM) is fine-tuned, it's time to evaluate its performance! In this section, we'll test the model using examples from the ChartQA dataset to assess how accurately it answers questions based on chart images. Let's dive into the results and see how well it performs! 🚀

Let's clean up the GPU memory to ensure optimal performance 🧹

>>> clear_memory()
GPU allocated memory: 16.34 GB
GPU reserved memory: 18.69 GB

We will reload the base model using the same pipeline as before.

model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

processor = AutoProcessor.from_pretrained(model_id)

We will attach the trained adapter to the pretrained model. This adapter contains the fine-tuning adjustments made during training, enabling the base model to leverage the new knowledge while keeping its core parameters intact. By integrating the adapter, we enhance the model's capabilities without altering its original structure.

adapter_path = "sergiopaniego/smolvlm-instruct-trl-sft-ChartQA"
model.load_adapter(adapter_path)
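
To illustrate what "keeping its core parameters intact" means, the sketch below computes the effective weight of a plain LoRA layer as W + (alpha / rank) * B @ A on a toy matrix, leaving W itself untouched. This is a simplified picture with made-up values: the adapter trained above uses `use_dora=True`, which adds a magnitude and direction decomposition that is not shown here.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_effective_weight(base, lora_a, lora_b, alpha, rank):
    """Return W + (alpha / rank) * B @ A without modifying the base weight."""
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(base, delta)]


# Toy 2x2 example with rank 1; all values are made up.
base_weight = [[1.0, 0.0], [0.0, 1.0]]
lora_a = [[1.0, 2.0]]    # shape (rank, in_features)
lora_b = [[0.5], [0.0]]  # shape (out_features, rank)

print(lora_effective_weight(base_weight, lora_a, lora_b, alpha=1, rank=1))
# [[1.5, 1.0], [0.0, 1.0]]
```

Because the base weight is never overwritten, the same pretrained checkpoint can be paired with different adapters.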

Let's evaluate the model on an unseen sample.

test_dataset[20][:2]

>>> test_dataset[20][1]["content"][0]["image"]

output = generate_text_from_sample(model, processor, test_dataset[20])
output

The model has successfully learned to respond to the queries as specified in the dataset. We've achieved our goal! 🎉✨

💻 I've developed an example application to test the model, which you can find here. You can easily compare it with another Space featuring the pre-trained model, available here.

from IPython.display import IFrame

IFrame(src="https://sergiopaniego-smolvlm-trl-sft-chartqa.hf.space", width=1000, height=800)

6. Continuing the Learning Journey 🧑‍🎓

To further enhance your skills with multimodal models, I recommend checking out the resources shared at the beginning of this notebook or revisiting the section with the same name in Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL).

These resources will help deepen your knowledge and expertise in multimodal learning.
