1. Flamingo Introduction: Few-Shot Learning for Visual Language Models
[caption id="attachment_65686" align="aligncenter" width="800"] DoryCredit: www.istockphoto.com[/caption]
Flamingo (original paper: https://arxiv.org/pdf/2204.14198) is a family of Visual Language Models (VLMs) developed by Google DeepMind to tackle the challenge of few-shot learning in multimodal machine learning. The model is built around three key architectural innovations:
It bridges powerful, pre-trained vision-only and language-only models.
It can handle sequences of arbitrarily interleaved visual and textual data.
It can seamlessly ingest images or videos as input.
This flexibility allows Flamingo to be trained on large-scale web data with mixed images and text, which is crucial for its ability to learn new tasks with only a few examples. As a result, a single Flamingo model can achieve state-of-the-art performance on a wide range of tasks, including visual question-answering, captioning, and multiple-choice questions simply by being prompted with task-specific examples. This few-shot approach often allows Flamingo to outperform models that have been fine-tuned on thousands of times more data.
2. How Flamingo Works
[caption id="attachment_65677" align="aligncenter" width="700"] Multimodal LLM. Credit: dataiku.com[/caption]
Flamingo operates through a multimodal interface, processing a combination of images, videos, and text to generate relevant textual responses. This design allows it to adapt seamlessly to different tasks, functioning similarly to large language models (LLMs), which use text-based prompts to tackle diverse language-related challenges.
Model Architecture
OpenFlamingo combines a pretrained vision encoder and a language model using cross-attention layers. The model architecture is shown below.
[caption id="attachment_65678" align="aligncenter" width="960"] Credit: Google Deep Mind[/caption]
The architecture can be understood by following two main pathways:
a. The Visual Pathway (Left side)
This pathway is responsible for processing the visual data (images) and preparing it for the language model.
Vision Encoder: This is a pre-trained model (indicated by the frozen snowflake icon) that extracts features from the input images. A key design choice is that this encoder’s weights are “frozen” and do not change during training.
Perceiver Resampler: The output of the Vision Encoder is then fed into the Perceiver Resampler. This module maps the variable-sized visual features to a small, fixed number of output tokens. This component is trained from scratch (indicated by the purple fill), learning to produce a concise summary of the visual data. In Flamingo, the Perceiver Resampler produces a fixed set of 64 output visual tokens per image; a minimal sketch of the idea follows below.
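To make the idea concrete, here is a minimal PyTorch-style sketch of a Perceiver-Resampler-like module, not the exact DeepMind implementation (the real one stacks several layers and concatenates latents into the keys/values): a small set of learned latent queries cross-attends to the flattened visual features, so the output size is fixed regardless of how many patches come in.
[code lang="js"]
import torch
import torch.nn as nn

class PerceiverResamplerSketch(nn.Module):
    """Sketch: learned latent queries cross-attend to visual features,
    producing a fixed number of visual tokens regardless of input length."""
    def __init__(self, dim=1024, num_latents=64, num_heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))  # fixed-size query set
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ff = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, 4 * dim),
                                nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, visual_features):
        # visual_features: (batch, n_patches, dim) -- n_patches can vary
        batch = visual_features.shape[0]
        queries = self.latents.unsqueeze(0).expand(batch, -1, -1)
        attended, _ = self.cross_attn(queries, visual_features, visual_features)
        return attended + self.ff(attended)  # (batch, num_latents, dim)

# Usage: 257 ViT patch features are compressed to 64 visual tokens.
feats = torch.randn(2, 257, 1024)
tokens = PerceiverResamplerSketch()(feats)
print(tokens.shape)  # torch.Size([2, 64, 1024])
[/code]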
b. The Language Pathway (Right Side)
This pathway processes the text and fuses it with the visual information to generate a final output.
Interleaved Input: The model takes an input sequence of text mixed with image placeholders (<image>).
LM Blocks: The core of this pathway is a large, pre-trained Language Model (LM) (like a Chinchilla model). Similar to the Vision Encoder, these blocks are “frozen,” meaning their vast knowledge of language is leveraged without needing to be retrained.
Gated XATTN-DENSE: This is the key innovation that connects the two pathways. These are new modules, trained from scratch, that are inserted between the LM blocks. When the model encounters an <image> placeholder in the text stream, the Gated XATTN-DENSE layer performs a cross-attention operation. It uses the text information as queries to "look at" the visual tokens generated by the Perceiver Resampler. The "gated" part is a mechanism that controls how much visual information is allowed to influence the language generation, providing a dynamic way to fuse the two modalities.
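The gating described above can be sketched in a few lines, assuming (for simplicity) that the text and visual tokens share the same hidden dimension; this is an illustration of the tanh-gating idea, not Flamingo's exact layer. Because the gates start at zero, the frozen LM initially behaves exactly as before, and visual influence is learned gradually.
[code lang="js"]
import torch
import torch.nn as nn

class GatedCrossAttentionSketch(nn.Module):
    """Sketch of a tanh-gated cross-attention block: text hidden states query
    visual tokens, and learnable gates (initialized to zero) control how much
    visual signal is mixed into the frozen language model's residual stream."""
    def __init__(self, dim=2048, num_heads=8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ff = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, 4 * dim),
                                nn.GELU(), nn.Linear(4 * dim, dim))
        # tanh(0) = 0, so at initialization this block is an identity mapping
        # and the pretrained LM's behaviour is preserved.
        self.attn_gate = nn.Parameter(torch.zeros(1))
        self.ff_gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_hidden, visual_tokens):
        # text_hidden:   (batch, n_text_tokens, dim)   -- queries
        # visual_tokens: (batch, n_visual_tokens, dim) -- keys/values
        attended, _ = self.cross_attn(text_hidden, visual_tokens, visual_tokens)
        x = text_hidden + torch.tanh(self.attn_gate) * attended
        x = x + torch.tanh(self.ff_gate) * self.ff(x)
        return x  # passed on to the next frozen LM block
[/code]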
Setting a New Standard in Few-Shot Learning
Flamingo has been rigorously tested on 16 different tasks and has consistently outperformed previous few-shot learning models, even when provided with as few as four examples per task. In several cases, it has demonstrated superior performance over methods that rely on extensive fine-tuning and significantly larger datasets, highlighting its ability to generalize effectively.
By minimizing the need for large-scale annotations and task-specific retraining, Flamingo represents a significant advancement in visual language model efficiency. Its ability to learn quickly from limited examples brings AI closer to human-like adaptability, enabling a wider range of real-world applications with greater ease and accuracy.
3. Why do we finetune it?
To validate the performance of our new H100 system, we’re testing its ability to run an LLM. For this evaluation, we’ve chosen to fine-tune a community-built implementation of the Flamingo model. This project serves a dual purpose:
System Validation: We’re using this fine-tuning task to rigorously test our H100 infrastructure, ensuring it can handle the demanding computational requirements of training and running a large model.
Code Verification: Since the original Flamingo model code wasn’t publicly released, we’re relying on a community-developed version. This process allows us to verify if this open-source implementation is a faithful and runnable recreation of the model described in the research paper.
Therefore, please note that we focus primarily on system capability here rather than on the model’s accuracy.
In this project, a Flamingo replica known as OpenFlamingo, developed by ML Foundations, was utilized since the original Flamingo model has not been publicly released. The objective was to fine-tune OpenFlamingo on its original dataset and evaluate its performance under controlled conditions. This experiment served two primary purposes: (1) assessing the model’s stability and reproducibility when fine-tuned on the same dataset, and (2) benchmarking its performance on an NVIDIA H100 GPU system to analyze computational efficiency, memory usage, and overall system capability for handling large-scale multimodal tasks. These insights help determine the feasibility of deploying OpenFlamingo in practical applications while optimizing hardware utilization.
4. How did we finetune it?
Installation
To install the package in an existing environment, run
[code lang="js"]
pip install open-flamingo
[/code]
or to create a conda environment for running OpenFlamingo, run
[code lang="js"]
conda env create -f environment.yml
[/code]
To install training or eval dependencies, run one of the first two commands. To install everything, run the third command.
[code lang="js"]
pip install open-flamingo[training]
pip install open-flamingo[eval]
pip install open-flamingo[all]
[/code]
There are three `requirements.txt` files:
- `requirements.txt`
- `requirements-training.txt`
- `requirements-eval.txt`
Depending on your use case, you can install any of these with `pip install -r <requirements-file.txt>`. The base file contains only the dependencies needed for running the model.
Development
The OpenFlamingo authors use pre-commit hooks to keep formatting aligned with the checks run in the repository.
To install pre-commit, run
[code lang="js"]
pip install pre-commit
[/code]
or use Homebrew on macOS
[code lang="js"]
brew install pre-commit
[/code]
Check the version installed with
[code lang="js"]
pre-commit --version
[/code]
Then at the root of this repository, run
[code lang="js"]
pre-commit install
[/code]
Then every time we run git commit, the checks are run. If the files are reformatted by the hooks, run
[code lang="js"]
git add
[/code]
for your changed files and
[code lang="js"]
git commit
[/code]
again.
Training Procedure
To train OpenFlamingo, please ensure your environment matches that of environment.yml.
Data Processing
The codebase uses WebDataset to efficiently load .tar files containing image and text sequences. The OpenFlamingo authors recommend resampling shards with replacement during training using the --dataset_resampled flag.
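For a sense of what that loading looks like, here is a rough WebDataset sketch; the shard pattern below is a placeholder, and the real pipeline in open_flamingo/train is more involved (batching, image decoding, sequence packing).
[code lang="js"]
import json
import webdataset as wds

# Placeholder shard pattern; point this at your own .tar files.
shards = "replicate_mmc4/{000000000..000000040}.tar"

# resampled=True samples shards with replacement, which is what the
# --dataset_resampled flag enables during training.
dataset = wds.WebDataset(shards, resampled=True).shuffle(1000).to_tuple("json")

for (raw,) in dataset:
    doc = json.loads(raw)  # each MMC4-style sample is a single .json file
    print(doc["id"], len(doc["text_list"]))
    break
[/code]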
LAION-2B Dataset
LAION-2B contains 2B web-scraped (image, text) pairs. Please use img2dataset to download this dataset into tar files.
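img2dataset is normally driven from the command line, but it also exposes a Python entry point; a hedged sketch of downloading into WebDataset-style tar shards is shown below. The metadata filename and column names are assumptions about your local copy of the LAION metadata, so adjust them to match your files.
[code lang="js"]
from img2dataset import download

# Hypothetical LAION-2B metadata file; the real dataset ships as many
# parquet files with URL and TEXT columns.
download(
    url_list="laion2b_metadata.parquet",
    input_format="parquet",
    url_col="URL",
    caption_col="TEXT",
    output_folder="laion2b_shards",
    output_format="webdataset",  # writes .tar shards directly
    processes_count=16,
    thread_count=64,
    image_size=256,
)
[/code]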
Multimodal C4 Dataset
OpenFlamingo trains on the full version of Multimodal C4 (MMC4), which includes 103M documents of web-scraped, interleaved image-text sequences. During training, it truncates sequences to 256 text tokens and six images per sequence. The codebase expects .tar files containing .json files, which include raw images encoded in base64.
Scripts are provided to convert MMC4 to this format: (1) Download the MMC4 shards into .zip files using the MMC4-provided scripts (e.g., fewer_facesv2.sh). (2) Download the MMC4 raw images into an image directory using the MMC4-provided scripts (e.g., download_images.py). (3) Run scripts/convert_mmc4_to_wds.py to convert the downloaded items into the expected tar files.
Customized Dataset
It has recently been reported that the MMC4 dataset download URLs have access issues. We have therefore written a script that prepares a customized dataset by converting it into MMC4’s format (we used the ADNI dataset as the target for this example, with a fixed sample base64 image). You can adapt this script to your own custom dataset:
[code lang="js"]
import json
import os
import tarfile
def compress_directory_to_tar(directory_path):
json_files = [f for f in os.listdir(directory_path) if f.endswith('.json')]
os.makedirs('replicate_mmc4', exist_ok=True)
for i in range(0, len(json_files), 20):
batch_files = json_files[i:i+20]
tar_file_path = os.path.join('replicate_mmc4', f"{i//20:09d}.tar")
with tarfile.open(tar_file_path, "w:gz") as tar:
for file in batch_files:
tar.add(os.path.join(directory_path, file), arcname=file)
print(f"Batch {i//20} compressed to {tar_file_path}")
def convert_adni_to_mmc4(input_json_path, output_folder):
# Ensure the output folder exists
os.makedirs(output_folder, exist_ok=True)
# Load the large JSON file
with open(input_json_path, 'r') as f:
data = json.load(f)
matched_text_index = 0
# Iterate over each item in the list and save it as a separate JSON file
for idx, item in enumerate(data):
# Ensure compatibility with the structure of f9773b9c866145c28fe0b701dde8dfbe.json
# Handle text list:
conversations = item.get("conversations", None)
if conversations is not None:
text_list = []
for conversation in conversations:
text_list.append(conversation["value"])
# Check for &amp;lt;image&amp;gt; tag in the first element of conversations list
first_convo = conversations[0]["value"]
if "&amp;lt;image&amp;gt;" in first_convo:
if first_convo.startswith("&amp;lt;image&amp;gt;"):
matched_text_index = 0
elif first_convo.endswith("&amp;lt;image&amp;gt;"):
matched_text_index = 1
item["text_list"] = text_list
# Handle image's base64 content:
with open('./sample_base64.txt', 'r') as f:
sample_img_base64_data = f.read()
# Handle image info:
img_info = []
images_list = item.get("image", None)
if images_list is not None:
for img in images_list:
img_obj = {}
img_obj["image_name"] = img
img_obj["raw_url"] = "https://example.com/{}".format(img)
img_obj["matched_text_index"] = matched_text_index
img_obj["matched_sim"] = 0.75
img_obj["image_base64"] = sample_img_base64_data
img_info.append(img_obj)
# Create similarity_matrix
similarity_matrix = []
for img in img_info:
for _ in range(len(text_list)):
inner_list = [0] * len(text_list)
inner_list[matched_text_index] = 1
similarity_matrix.append(inner_list)
# item["similarity_matrix"] = similarity_matrix
output_item = {
"id": item.get("id", None),
"url": "https://example.com",
"text_list": item.get("text_list", None),
"image_info": img_info,
"similarity_matrix": similarity_matrix,
"could_have_url_duplicate": 0
}
# Save the item as a separate JSON file
output_path = os.path.join(output_folder, f"{idx:05d}.json")
with open(output_path, 'w') as out_f:
json.dump(output_item, out_f)
[/code]
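The two helpers can then be chained; a usage sketch follows, where the paths are placeholders for your own ADNI-style JSON file and output directory.
[code lang="js"]
# Hypothetical paths; the shards land in ./replicate_mmc4/, which is what the
# --mmc4_shards pattern in the training command below points at.
convert_adni_to_mmc4("VLM_ADNI_DATA/adni_conversations.json", "VLM_ADNI_DATA/mmc4_json")
compress_directory_to_tar("VLM_ADNI_DATA/mmc4_json")
[/code]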
ChatGPT-generated sequences
A subset of the OpenFlamingo models (listed below) were also trained on experimental ChatGPT-generated (image, text) sequences, where images are pulled from LAION. The shards containing these sequences can be found at this CodaLab worksheet. The authors are unable to distribute raw images in the released shards; images must be pre-downloaded from the URLs in the JSON files and converted to base64 before using this data for training in the codebase (a sketch of that pre-download step follows the model list below).
Models trained with ChatGPT-generated sequences:
OpenFlamingo-4B-vitl-rpj3b
OpenFlamingo-4B-vitl-rpj3b-langinstruct
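As noted above, the raw images must be fetched from the URLs in the JSON files and embedded as base64 before training. A minimal sketch of that step is shown below; the field names follow the MMC4-style layout used in this post, and error handling and rate limiting are omitted.
[code lang="js"]
import base64
import json
import requests

def add_base64_images(json_path):
    """Download each image referenced in a document's image_info and store it
    inline as a base64 string, as the training codebase expects."""
    with open(json_path, "r") as f:
        doc = json.load(f)
    for img in doc.get("image_info", []):
        resp = requests.get(img["raw_url"], timeout=10)
        resp.raise_for_status()
        img["image_base64"] = base64.b64encode(resp.content).decode("utf-8")
    with open(json_path, "w") as f:
        json.dump(doc, f)
[/code]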
Training Command
A sample Slurm training script is provided in scripts/. You can also adapt the following command, which is the one we used in our case:
[code lang="js"]
torchrun --nnodes=1 --nproc_per_node=8 open_flamingo/train/train.py \
--lm_path anas-awadalla/mpt-1b-redpajama-200b \
--tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
--cross_attn_every_n_layers 1 \
--dataset_resampled \
--batch_size_mmc4 2 \
--train_num_samples_mmc4 1000 \
--workers=4 \
--run_name OpenFlamingo-3B-vitl-mpt1b \
--num_epochs 20 \
--warmup_steps 1875 \
--mmc4_textsim_threshold 0.24 \
--mmc4_shards "modifications/VLM_ADNI_DATA/replicate_mmc4/{000000000..000000040}.tar" \
--report_to_wandb
[/code]
The MPT-1B base and instruct modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as the OpenFlamingo codebase expects. The OpenFlamingo authors suggest using the modified versions of the MPT-1B models found here and here.
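If you would rather not swap in the modified checkpoints, one workaround (a generic sketch, not the codebase's official solution) is to compute the shifted causal-LM cross-entropy from the logits yourself:
[code lang="js"]
import torch.nn.functional as F

def language_model_loss(logits, labels, ignore_index=-100):
    """Standard causal-LM loss: predict token t+1 from positions <= t.
    Mirrors what a labels-aware forward() would normally return."""
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )
[/code]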
Distributed Training
By default, train.py uses PyTorch’s DistributedDataParallel for training.
To use FullyShardedDataParallel, add the --fsdp flag.
Some notes on FSDP from the OpenFlamingo team:
We recommend using the --fsdp_use_orig_params flag. If --fsdp is enabled without this flag, all language model embeddings will be unfrozen during training. (In contrast, the default behavior is to only train the newly added <image> and <|endofchunk|> tokens.)
Note: We’ve encountered issues using OPT with this flag. Other language models should be compatible.
Our current FSDP wrapping strategy does not permit training language model embeddings that use tied weights (i.e., tied input/output embeddings). To train such models with FSDP, the language model embeddings must be frozen with the --freeze_lm_embeddings flag.
We also implement gradient checkpointing and mixed precision training. Use the --gradient_checkpointing and --precision arguments, respectively.
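For context, the precision handling corresponds roughly to PyTorch automatic mixed precision; a generic sketch of such a training step (not the OpenFlamingo training loop itself) looks like this:
[code lang="js"]
import torch

# Generic mixed-precision sketch: autocast runs the forward pass in reduced
# precision, while GradScaler rescales gradients so small fp16 values are not
# flushed to zero.
scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, batch, loss_fn):
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(batch))
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
[/code]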
Initializing an OpenFlamingo model
OpenFlamingo supports pretrained vision encoders from the OpenCLIP package, which includes OpenAI’s pretrained models.
It also supports pretrained language models from the transformers package, such as MPT, RedPajama, LLaMA, OPT, GPT-Neo, GPT-J, and Pythia models.
[code lang="js"]
from open_flamingo import create_model_and_transforms
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
    cache_dir="PATH/TO/CACHE/DIR",  # defaults to ~/.cache
)
[/code]
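With the model created, a quick generation smoke test can follow the pattern from the OpenFlamingo README; the image URL and prompt below are placeholders, and the interleaved <image> / <|endofchunk|> tokens mark where images belong in the text.
[code lang="js"]
import requests
from PIL import Image

# Placeholder image URL; any RGB image works for a smoke test.
image = Image.open(requests.get("https://example.com/cat.jpg", stream=True).raw)

# vision_x shape: (batch, num_images, frames, channels, height, width)
vision_x = image_processor(image).unsqueeze(0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)

tokenizer.padding_side = "left"  # left-padding for generation
lang_x = tokenizer(["<image>An image of"], return_tensors="pt")

generated = model.generate(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=20,
    num_beams=3,
)
print(tokenizer.decode(generated[0]))
[/code]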
5. Results
Below are the results reported from our Weights & Biases (W&B) runs on the NVIDIA H100 GPUs.
The NVIDIA H100 system that was employed:
The system is equipped with 8x NVIDIA H100 80GB HBM3 GPUs; for this training setting, however, two GPUs with distributed training are sufficient.
Each NVIDIA H100 has 80GB of high-bandwidth memory (HBM3), making this a high-performance computing (HPC) or AI training system.
The NVIDIA H100 GPUs are in P0 performance state, which indicates they are in the highest available performance mode.
Model’s reported metrics
[caption id="attachment_65687" align="aligncenter" width="960"] Credit: Wandbs.com[/caption]
The training metrics indicate a well-functioning process with expected behaviors across various parameters. The loss curve shows a sharp initial drop before stabilizing, suggesting good convergence. The learning rate follows a linear warm-up schedule, which is a common practice to stabilize early training. Step time and data loading times remain mostly consistent, with occasional spikes that may be caused by system fluctuations, checkpointing, or data fetching delays. The global step progresses linearly, confirming steady training iteration increments. The samples per second per GPU metric remains stable, with a minor dip that does not appear to significantly impact performance. Overall, these metrics suggest normal training behavior, though monitoring occasional spikes in step time and data time could help optimize efficiency further.
System’s reported metrics (what we care about more):
[caption id="attachment_65688" align="aligncenter" width="960"] Credit: Wandbs.com[/caption]
GPU Uncorrected Memory Errors (Top-left): The line remains at zero, indicating no uncorrected memory errors.
GPU Corrected Memory Errors (Top-middle): The plot is also flat at zero, meaning no corrected memory errors.
GPU Memory Clock Speed (Top-right): Normal; consistent clock speed suggests no dynamic frequency scaling or throttling.
GPU Streaming Multiprocessor (SM) Clock Speed (Bottom-left): Normal; stable clock speed suggests no thermal throttling.
GPU Power Usage (W) (Bottom-middle): Shows a cyclical pattern, indicating that GPU power consumption fluctuates during workload execution; this could be due to batch processing, workload scheduling, or dynamic power management.
[caption id="attachment_65689" align="aligncenter" width="960"] Credit: Wandbs.coM[/caption]
GPU Enforced Power Limit (W) (Top-left): Normal; this indicates that the GPU is not exceeding its predefined power limit.
GPU Memory Allocated (Bytes) (Top-middle): Memory allocation remains stable but drops suddenly at the end; the drop occurs when training finishes.
GPU Memory Allocated (%) (Top-right): Normal, same as GPU Memory Allocated (Bytes).
GPU Time Spent Accessing Memory (%) (Bottom-left): Correlates with GPU Power Usage (W) above.
GPU Temperature (°C) (Bottom-middle): Correlates with GPU Power Usage (W) above.
GPU Utilization (%) (Bottom-right): Correlates with GPU Power Usage (W) above.