Fine-Tuning OpenFlamingo on NVIDIA H100 GPUs

Author: Nguyễn Ngọc Mai
16:43 21/08/2025

1. Flamingo Introduction: Few-Shot Learning for Visual Language Models

[Image credit: www.istockphoto.com]

Flamingo (original paper: https://arxiv.org/pdf/2204.14198) is a family of Visual Language Models (VLMs) designed by a team at Google DeepMind to address the challenge of few-shot learning in multimodal machine learning. The model is built on three key architectural innovations:

  • It bridges powerful, pre-trained vision-only and language-only models.
  • It can handle sequences of arbitrarily interleaved visual and textual data.
  • It can seamlessly ingest images or videos as input.

This flexibility allows Flamingo to be trained on large-scale web data containing interleaved images and text, which is crucial to its ability to learn new tasks from only a few examples. As a result, a single Flamingo model can achieve state-of-the-art performance on a wide range of tasks, including visual question answering, captioning, and multiple-choice questions, simply by being prompted with task-specific examples. This few-shot approach often allows Flamingo to outperform models fine-tuned on thousands of times more task-specific data.

2. How Flamingo Works

[Figure: Multimodal LLM. Credit: dataiku.com]

Flamingo operates through a multimodal interface, processing a combination of images, videos, and text to generate relevant textual responses. This design allows it to adapt seamlessly to different tasks, functioning similarly to large language models (LLMs), which use text-based prompts to tackle diverse language-related challenges.

Model Architecture

OpenFlamingo combines a pretrained vision encoder and a language model using cross-attention layers. The model architecture is shown below.
[Figure: Flamingo model architecture. Credit: Google DeepMind]

The architecture can be understood by following two main pathways:

a. The Visual Pathway (Left Side)

This pathway is responsible for processing the visual data (images) and preparing it for the language model.

  • Vision Encoder: This is a pre-trained model (indicated by the frozen snowflake icon) that extracts features from the input images. A key design choice is that this encoder’s weights are “frozen” and do not change during training.
  • Perceiver Resampler: The output of the Vision Encoder is then fed into the Perceiver Resampler. This module maps the variable-sized visual features to a small, fixed number of output tokens. This component is trained from scratch (indicated by the purple fill), learning to produce a concise summary of the visual data. In Flamingo, the number of output visual tokens is fixed at 64.
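
To make the idea concrete, below is a minimal PyTorch sketch of a Perceiver-Resampler-style module: a fixed set of learned latent queries cross-attends to the variable-length vision-encoder features and returns a fixed-size set of visual tokens. This is only an illustration under assumed dimensions, not DeepMind's or OpenFlamingo's implementation.

import torch
import torch.nn as nn

class PerceiverResamplerSketch(nn.Module):
    # Learned latent queries attend to visual features and emit a fixed number of tokens.
    def __init__(self, dim=1024, num_latents=64, num_heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))  # fixed-size query set
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.LayerNorm(dim), nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, visual_feats):                             # visual_feats: (B, N_patches, dim)
        b = visual_feats.shape[0]
        latents = self.latents.unsqueeze(0).expand(b, -1, -1)    # (B, num_latents, dim)
        attended, _ = self.cross_attn(latents, visual_feats, visual_feats)
        latents = latents + attended                             # residual connection
        return latents + self.ff(latents)                        # (B, num_latents, dim): fixed size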

b. The Language Pathway (Right Side)

This pathway processes the text and fuses it with the visual information to generate a final output.

  • Interleaved Input: The model takes an input sequence of text mixed with image placeholders (<image>).
  • LM Blocks: The core of this pathway is a large, pre-trained Language Model (LM) (like a Chinchilla model). Similar to the Vision Encoder, these blocks are “frozen,” meaning their vast knowledge of language is leveraged without needing to be retrained.
  • Gated XATTN-DENSE: This is the key innovation that connects the two pathways. These are new modules, trained from scratch, that are inserted between the LM blocks. When the model encounters an <image> placeholder in the text stream, the Gated XATTN-DENSE layer performs a cross-attention operation. It uses the text information as queries to "look at" the visual tokens generated by the Perceiver Resampler. The "gated" part is a mechanism that controls how much visual information is allowed to influence the language generation, providing a dynamic way to fuse the two modalities.
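
The gating can be made concrete with a minimal PyTorch sketch: the outputs of cross-attention and a feed-forward layer are scaled by tanh of learnable gates initialized at zero, so the frozen LM's behavior is unchanged at the start of training and visual information is blended in gradually. This is an illustration under assumed dimensions, not the actual GATED XATTN-DENSE code.

import torch
import torch.nn as nn

class GatedCrossAttentionSketch(nn.Module):
    def __init__(self, dim=2048, num_heads=8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # Zero-initialized gates: tanh(0) = 0, so the block starts as an identity mapping.
        self.attn_gate = nn.Parameter(torch.zeros(1))
        self.ff_gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_tokens, visual_tokens):
        # Text tokens are the queries; visual tokens from the Perceiver Resampler
        # serve as keys and values.
        attended, _ = self.cross_attn(text_tokens, visual_tokens, visual_tokens)
        text_tokens = text_tokens + torch.tanh(self.attn_gate) * attended
        text_tokens = text_tokens + torch.tanh(self.ff_gate) * self.ff(text_tokens)
        return text_tokens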

Setting a New Standard in Few-Shot Learning

Flamingo has been rigorously tested on 16 different tasks and has consistently outperformed previous few-shot learning models, even when provided with as few as four examples per task. In several cases, it has demonstrated superior performance over methods that rely on extensive fine-tuning and significantly larger datasets, highlighting its ability to generalize effectively.

By minimizing the need for large-scale annotations and task-specific retraining, Flamingo represents a significant advancement in visual language model efficiency. Its ability to learn quickly from limited examples brings AI closer to human-like adaptability, enabling a wider range of real-world applications with greater ease and accuracy.

3. Why do we fine-tune it?

To validate the performance of our new H100 system, we're testing its ability to train and run a large multimodal model. For this evaluation, we've chosen to fine-tune a community-built implementation of the Flamingo model. This project serves a dual purpose:

  • System Validation: We’re using this fine-tuning task to rigorously test our H100 infrastructure, ensuring it can handle the demanding computational requirements of training and running a large model.
  • Code Verification: Since the original Flamingo model code wasn’t publicly released, we’re relying on a community-developed version. This process allows us to verify if this open-source implementation is a faithful and runnable recreation of the model described in the research paper.

Therefore, please note that we focus primarily on system capability here rather than on evaluating the model's accuracy.

In this project, a Flamingo replica known as OpenFlamingo, developed by ML Foundations, was utilized, since the original Flamingo model has not been publicly released. The objective was to fine-tune OpenFlamingo on its original dataset and evaluate its performance under controlled conditions. This experiment served two primary purposes: (1) assessing the model's stability and reproducibility when fine-tuned on the same dataset, and (2) benchmarking its performance on an NVIDIA H100 GPU system to analyze computational efficiency, memory usage, and overall system capability for handling large-scale multimodal tasks. These insights help determine the feasibility of deploying OpenFlamingo in practical applications while optimizing hardware utilization.

4. How did we fine-tune it?


Installation

To install the package in an existing environment, run

pip install open-flamingo 

or to create a conda environment for running OpenFlamingo, run

conda env create -f environment.yml 

To install training or eval dependencies, run one of the first two commands. To install everything, run the third command.

pip install open-flamingo[training]
pip install open-flamingo[eval] 
pip install open-flamingo[all]

There are three `requirements.txt` files:
- `requirements.txt`
- `requirements-training.txt`
- `requirements-eval.txt`

Depending on your use case, you can install any of these with `pip install -r <requirements-file.txt>`. The base file contains only the dependencies needed for running the model.

Development

The open-source authors use pre-commit hooks to keep formatting aligned with the checks in the repository.
 
To install pre-commit, run
pip install pre-commit 

or use Homebrew on macOS

brew install pre-commit  

Check the version installed with

pre-commit --version

Then at the root of this repository, run

pre-commit install    

Then, every time you run git commit, the checks run automatically. If the hooks reformat any files, run

git add

for your changed files and

git commit

again.

Training Procedure

To train OpenFlamingo, please ensure your environment matches that of environment.yml.

Data Processing

The codebase uses WebDataset to efficiently load .tar files containing image and text sequences. We recommend resampling shards with replacement during training using the --dataset_resampled flag.
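
For reference, a minimal sketch of iterating such .tar shards with WebDataset is shown below. This is not the OpenFlamingo data pipeline itself; the shard path is a placeholder, and resampled=True (which mirrors what --dataset_resampled enables) assumes a recent webdataset release.

import webdataset as wds

shards = "path/to/shards/{000000000..000000040}.tar"   # placeholder brace-expanded shard list
dataset = (
    wds.WebDataset(shards, resampled=True)  # sample shards with replacement
    .decode()                               # decode entries by file extension (.json -> dict)
    .to_tuple("json")
)

for (sample,) in dataset:
    print(sample["text_list"][:1])          # peek at the first text entry of an MMC4-style sample
    break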

  • LAION-2B Dataset
    LAION-2B contains 2B web-scraped (image, text) pairs. Please use img2dataset to download this dataset into tar files.
  • Multimodal C4 Dataset
    OpenFlamingo trains on the full version of Multimodal C4 (MMC4), which includes 103M documents of web-scraped, interleaved image-text sequences. During training, it truncates sequences to 256 text tokens and six images per sequence. The codebase expects .tar files containing .json files, which include raw images encoded in base64.
    Scripts are provided to convert MMC4 to this format: (1) Download the MMC4 shards into .zip files using the MMC4-provided scripts (e.g., fewer_facesv2.sh). (2) Download the MMC4 raw images into an image directory using the MMC4-provided scripts (e.g., download_images.py). (3) Run scripts/convert_mmc4_to_wds.py to convert the downloaded items into the expected tar files.
  • Customized Dataset
    It has recently been reported that the MMC4 dataset download URLs have access issues. Therefore, we have written a script that prepares a customized dataset by transforming it into MMC4's format (we used the ADNI dataset for this example, with fixed sample base64 image data). You can adapt this script to your own dataset:
import json
import os
import tarfile

def compress_directory_to_tar(directory_path):
    # Package the per-sample .json files into WebDataset-style .tar shards,
    # 20 files per shard, written to ./replicate_mmc4/.
    json_files = [f for f in os.listdir(directory_path) if f.endswith('.json')]
    os.makedirs('replicate_mmc4', exist_ok=True)

    for i in range(0, len(json_files), 20):
        batch_files = json_files[i:i + 20]
        tar_file_path = os.path.join('replicate_mmc4', f"{i // 20:09d}.tar")

        with tarfile.open(tar_file_path, "w:gz") as tar:
            for file in batch_files:
                tar.add(os.path.join(directory_path, file), arcname=file)

        print(f"Batch {i // 20} compressed to {tar_file_path}")

def convert_adni_to_mmc4(input_json_path, output_folder):
    # Ensure the output folder exists
    os.makedirs(output_folder, exist_ok=True)

    # Load the large JSON file
    with open(input_json_path, 'r') as f:
        data = json.load(f)

    # Fixed sample base64 image data reused for every image entry
    with open('./sample_base64.txt', 'r') as f:
        sample_img_base64_data = f.read()

    # Iterate over each item in the list and save it as a separate JSON file,
    # mirroring the structure of the MMC4 per-document .json files
    for idx, item in enumerate(data):
        matched_text_index = 0

        # Handle text list: collect the conversation turns as MMC4's text_list
        text_list = []
        conversations = item.get("conversations", None)
        if conversations is not None:
            for conversation in conversations:
                text_list.append(conversation["value"])

            # Use the position of the <image> tag in the first turn to decide
            # which text entry the image is matched to
            first_convo = conversations[0]["value"]
            if "<image>" in first_convo:
                if first_convo.startswith("<image>"):
                    matched_text_index = 0
                elif first_convo.endswith("<image>"):
                    matched_text_index = 1

        # Handle image info: one entry per image, all sharing the sample base64 payload
        img_info = []
        images_list = item.get("image", None)
        if images_list is not None:
            for img in images_list:
                img_info.append({
                    "image_name": img,
                    "raw_url": "https://example.com/{}".format(img),
                    "matched_text_index": matched_text_index,
                    "matched_sim": 0.75,
                    "image_base64": sample_img_base64_data,
                })

        # Create similarity_matrix: one row per image, one column per text entry
        similarity_matrix = []
        for _ in img_info:
            inner_list = [0] * len(text_list)
            if text_list:
                inner_list[matched_text_index] = 1
            similarity_matrix.append(inner_list)

        output_item = {
            "id": item.get("id", None),
            "url": "https://example.com",
            "text_list": text_list,
            "image_info": img_info,
            "similarity_matrix": similarity_matrix,
            "could_have_url_duplicate": 0,
        }

        # Save the item as a separate JSON file
        output_path = os.path.join(output_folder, f"{idx:05d}.json")
        with open(output_path, 'w') as out_f:
            json.dump(output_item, out_f)
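
As a quick usage sketch (the file and folder names are hypothetical), the two helpers can be chained to produce WebDataset-style shards like the ones referenced in the training command below:

# Hypothetical paths; adjust them to your own dataset layout.
convert_adni_to_mmc4("adni_annotations.json", "adni_mmc4_json")  # writes one .json per sample
compress_directory_to_tar("adni_mmc4_json")                      # shards land in ./replicate_mmc4/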

ChatGPT-generated sequences
A subset of the OpenFlamingo models (listed below) were also trained on experimental ChatGPT-generated (image, text) sequences, where images are pulled from LAION. The shards containing these sequences can be found at this CodaLab worksheet. The authors are unable to distribute the raw images in the released shards; images must be pre-downloaded from the URLs in the JSON files and converted to base64 before this data can be used for training in the codebase (a minimal download-and-encode sketch follows the model list below).
Models trained with ChatGPT-generated sequences:

  • OpenFlamingo-4B-vitl-rpj3b
  • OpenFlamingo-4B-vitl-rpj3b-langinstruct
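
A minimal sketch of that pre-download step might look like the following; fetch_image_as_base64 is a hypothetical helper, and real use would need error handling and rate limiting.

import base64
import urllib.request

def fetch_image_as_base64(url, timeout=10):
    # Download an image from `url` and return its bytes as a base64 string,
    # the encoding expected in MMC4-style .json files.
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return base64.b64encode(resp.read()).decode("utf-8")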

Training Command
A sample Slurm training script is provided in scripts/. You can also modify the following command (which was used in our case):

torchrun --nnodes=1 --nproc_per_node=8 open_flamingo/train/train.py \
 --lm_path anas-awadalla/mpt-1b-redpajama-200b \
 --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
 --cross_attn_every_n_layers 1 \
 --dataset_resampled \
 --batch_size_mmc4 2 \
 --train_num_samples_mmc4 1000 \
 --workers=4 \
 --run_name OpenFlamingo-3B-vitl-mpt1b \
 --num_epochs 20 \
 --warmup_steps 1875 \
 --mmc4_textsim_threshold 0.24 \
 --mmc4_shards "modifications/VLM_ADNI_DATA/replicate_mmc4/{000000000..000000040}.tar" \
 --report_to_wandb    

The MPT-1B base and instruct modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase. We suggest using a modified version of the MPT-1B models found here and here.
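
For context, the next-token cross-entropy that such a modified forward() needs to compute looks roughly like the generic sketch below; this is not the actual MPT-1B patch.

import torch
import torch.nn.functional as F

def next_token_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # Standard causal LM loss: predict token t+1 from tokens up to t.
    shift_logits = logits[:, :-1, :].contiguous()   # drop the last position
    shift_labels = input_ids[:, 1:].contiguous()    # drop the first token
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )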

Distributed training

By default, train.py uses PyTorch's DistributedDataParallel for training.
To use FullyShardedDataParallel, use the --fsdp flag.

Some notes on FSDP from the OpenFlamingo team:
We recommend using the --fsdp_use_orig_params flag. If --fsdp is used without this flag, all language model embeddings will be unfrozen during training. (In contrast, the default behavior is to train only the newly added <image> and <|endofchunk|> tokens.)
Note: We've encountered issues using OPT with this flag. Other language models should be compatible.
Our current FSDP wrapping strategy does not permit training language model embeddings that use tied weights (i.e., tied input/output embeddings). To train such models with FSDP, the language model embeddings must be frozen with the --freeze_lm_embeddings flag.

We also implement gradient checkpointing and mixed-precision training. Use the --gradient_checkpointing and --precision arguments, respectively.

Initializing an OpenFlamingo model

OpenFlamingo supports pretrained vision encoders from the OpenCLIP package, which includes OpenAI’s pretrained models.
It also supports pretrained language models from the transformers package, such as MPT, RedPajama, LLaMA, OPT, GPT-Neo, GPT-J, and Pythia models.

from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
    cache_dir="PATH/TO/CACHE/DIR",  # Defaults to ~/.cache
)

5. Results

Below are the results reported from our Weights & Biases (W&B) dashboards.
The NVIDIA H100 system that was employed:

  • The system is equipped with 8× NVIDIA H100 80GB HBM3 GPUs; however, for this training setup, 2 GPUs with distributed training are sufficient.
  • Each NVIDIA H100 has 80GB of high-bandwidth memory (HBM3), making this a high-performance computing (HPC) or AI training system.
  • The NVIDIA H100 GPUs are in P0 performance state, which indicates they are in the highest available performance mode.


Model’s reported metrics

[Figure: model training metrics. Credit: wandb.ai]

The training metrics indicate a well-functioning process with expected behavior across parameters. The loss curve shows a sharp initial drop before stabilizing, suggesting good convergence. The learning rate follows a linear warm-up schedule, a common practice to stabilize early training. Step time and data-loading times remain mostly consistent, with occasional spikes that may be caused by system fluctuations, checkpointing, or data-fetching delays. The global step progresses linearly, confirming steady training iteration increments. The samples-per-second-per-GPU metric remains stable, with a minor dip that does not appear to significantly impact performance. Overall, these metrics suggest normal training behavior, though monitoring the occasional spikes in step time and data time could help optimize efficiency further.
System’s reported metrics (which we care about more):
[Figure: GPU memory errors, clock speeds, and power usage. Credit: wandb.ai]

  • GPU Uncorrected Memory Errors (Top-left): The line remains at zero, indicating no uncorrected memory errors.
  • GPU Corrected Memory Errors (Top-middle): The plot is also flat at zero, meaning no corrected memory errors.
  • GPU Memory Clock Speed (Top-right): Normal; consistent clock speed suggests no dynamic frequency scaling or throttling.
  • GPU Streaming Multiprocessor (SM) Clock Speed (Bottom-left): Normal; stable clock speed suggests no thermal throttling.
  • GPU Power Usage (W) (Bottom-middle): Shows a cyclical pattern, indicating that GPU power consumption fluctuates during workload execution; this could be due to batch processing, workload scheduling, or dynamic power management.
[Figure: GPU power limit, memory allocation, temperature, and utilization. Credit: wandb.ai]

  • GPU Enforced Power Limit (W) (Top-left): Normal; this indicates that the GPU is not exceeding its predefined power limit.
  • GPU Memory Allocated (Bytes) (Top-middle): Memory allocation remains stable but drops suddenly at the end; the drop occurs when training finishes.
  • GPU Memory Allocated (%) (Top-right): Normal, same as GPU Memory Allocated (Bytes).
  • GPU Time Spent Accessing Memory (%) (Bottom-left): Correlates with GPU Power Usage (W) above.
  • GPU Temperature (°C) (Bottom-middle): Correlates with GPU Power Usage (W) above.
  • GPU Utilization (%) (Bottom-right): Correlates with GPU Power Usage (W) above.