Fine-Tuning OpenFlamingo on NVIDIA H100 GPUs
Flamingo (original paper: https://arxiv.org/pdf/2204.14198) is a family of Visual Language Models (VLMs) developed by Google DeepMind to address the challenge of few-shot learning in multimodal machine learning. The model is built on three key architectural innovations:
- It bridges powerful pretrained vision-only and language-only models.
- It handles sequences of arbitrarily interleaved visual and textual data.
- It seamlessly ingests images or videos as inputs.
This flexibility allows Flamingo to be trained on large-scale web data with mixed images and text, which is crucial for its ability to learn new tasks with only a few examples. As a result, a single Flamingo model can achieve state-of-the-art performance on a wide range of tasks, including visual question-answering, captioning, and multiple-choice questions simply by being prompted with task-specific examples. This few-shot approach often allows Flamingo to outperform models that have been fine-tuned on thousands of times more data.
Model Architecture
The architecture can be understood by following two main pathways:
a. The Visual Pathway (left side)
This pathway processes the visual data (images or videos) and prepares it for the language model: a frozen, pretrained vision encoder extracts features, and a Perceiver Resampler condenses them into a small, fixed number of visual tokens.
b. The Language Pathway (right side)
This pathway processes the text with a frozen, pretrained language model whose layers are interleaved with newly added, gated cross-attention layers; these attend to the visual tokens and fuse the two modalities to generate the final output.
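To make the two pathways concrete, below is a minimal, schematic PyTorch sketch. It is not OpenFlamingo's actual implementation; the class names, dimensions, and single-attention-layer blocks are illustrative assumptions. The visual pathway compresses a variable number of image features into a fixed set of visual tokens, and the language pathway mixes those tokens into the text stream through a zero-initialized, tanh-gated cross-attention block, so the frozen language model's behavior is unchanged at the start of training.

import torch
import torch.nn as nn

class PerceiverResamplerSketch(nn.Module):
    """Visual pathway (sketch): cross-attend learned latents to image features
    to produce a fixed number of visual tokens."""
    def __init__(self, dim=1024, num_latents=64, num_heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, image_features):  # image_features: (batch, num_patches, dim)
        latents = self.latents.unsqueeze(0).expand(image_features.size(0), -1, -1)
        visual_tokens, _ = self.attn(latents, image_features, image_features)
        return visual_tokens  # (batch, num_latents, dim)

class GatedCrossAttentionSketch(nn.Module):
    """Language pathway (sketch): cross-attention whose output is scaled by a
    tanh gate initialized at zero, leaving the frozen LM untouched at init."""
    def __init__(self, dim=1024, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.gate = nn.Parameter(torch.zeros(1))  # tanh(0) = 0

    def forward(self, text_hidden, visual_tokens):
        attended, _ = self.attn(text_hidden, visual_tokens, visual_tokens)
        return text_hidden + torch.tanh(self.gate) * attended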
Setting a New Standard in Few-Shot Learning
Flamingo has been rigorously tested on 16 different tasks and has consistently outperformed previous few-shot learning models, even when provided with as few as four examples per task. In several cases, it has demonstrated superior performance over methods that rely on extensive fine-tuning and significantly larger datasets, highlighting its ability to generalize effectively.
By minimizing the need for large-scale annotations and task-specific retraining, Flamingo represents a significant advancement in visual language model efficiency. Its ability to learn quickly from limited examples brings AI closer to human-like adaptability, enabling a wider range of real-world applications with greater ease and accuracy.
To validate the performance of our new H100 system, we are testing its ability to run an LLM. For this evaluation, we chose to fine-tune a community-built implementation of the Flamingo model. This project serves a dual purpose, described below.
Please note that the focus here is on system capability rather than on evaluating the model's accuracy.
In this project, we used OpenFlamingo, a Flamingo replica developed by ML Foundations, since the original Flamingo model has not been publicly released. The objective was to fine-tune OpenFlamingo on its original dataset and evaluate its performance under controlled conditions. The experiment served two primary purposes: (1) assessing the model's stability and reproducibility when fine-tuned on the same dataset, and (2) benchmarking its performance on an NVIDIA H100 GPU system to analyze computational efficiency, memory usage, and overall system capability for handling large-scale multimodal tasks. These insights help determine the feasibility of deploying OpenFlamingo in practical applications while optimizing hardware utilization.
Installation
To install the package in an existing environment, run
pip install open-flamingo
or to create a conda environment for running OpenFlamingo, run
conda env create -f environment.yml
To install training or eval dependencies, run one of the first two commands below. To install everything, run the third command.
pip install open-flamingo[training]
pip install open-flamingo[eval]
pip install open-flamingo[all]
There are three `requirements.txt` files:
- `requirements.txt`
- `requirements-training.txt`
- `requirements-eval.txt`
Depending on your use case, you can install any of these with `pip install -r <requirements-file.txt>`. The base file contains only the dependencies needed for running the model.
Development
pip install pre-commit
or use Homebrew on macOS
brew install pre-commit
Check the version installed with
pre-commit --version
Then at the root of this repository, run
pre-commit install
Then every time you run git commit, the checks are run automatically. If the hooks reformat any files, run
git add
for the changed files and then run
git commit
again.
Training Procedure
To train OpenFlamingo, please ensure your environment matches that of environment.yml.
Data Processing
The codebase uses WebDataset to efficiently load .tar files containing image and text sequences. We recommend resampling shards with replacement during training using the --dataset_resampled flag. The script below converts our ADNI-style annotations into MMC4-style JSON records and packs them into .tar shards:
import json
import os
import tarfile


def compress_directory_to_tar(directory_path):
    """Pack the per-record JSON files into .tar shards of 20 files each."""
    json_files = [f for f in os.listdir(directory_path) if f.endswith('.json')]
    os.makedirs('replicate_mmc4', exist_ok=True)
    for i in range(0, len(json_files), 20):
        batch_files = json_files[i:i + 20]
        tar_file_path = os.path.join('replicate_mmc4', f"{i // 20:09d}.tar")
        with tarfile.open(tar_file_path, "w:gz") as tar:
            for file in batch_files:
                tar.add(os.path.join(directory_path, file), arcname=file)
        print(f"Batch {i // 20} compressed to {tar_file_path}")


def convert_adni_to_mmc4(input_json_path, output_folder):
    """Convert ADNI-style conversation records into MMC4-style JSON records."""
    # Ensure the output folder exists
    os.makedirs(output_folder, exist_ok=True)

    # Load the large JSON file
    with open(input_json_path, 'r') as f:
        data = json.load(f)

    # Iterate over each item in the list and save it as a separate JSON file
    for idx, item in enumerate(data):
        # Ensure compatibility with the structure of f9773b9c866145c28fe0b701dde8dfbe.json
        # Handle text list:
        matched_text_index = 0
        text_list = []
        conversations = item.get("conversations", None)
        if conversations is not None:
            for conversation in conversations:
                text_list.append(conversation["value"])
            # Check for the <image> tag in the first element of the conversations list
            first_convo = conversations[0]["value"]
            if "<image>" in first_convo:
                if first_convo.startswith("<image>"):
                    matched_text_index = 0
                elif first_convo.endswith("<image>"):
                    matched_text_index = 1
            item["text_list"] = text_list

        # Handle the image's base64 content (placeholder content for every image):
        with open('./sample_base64.txt', 'r') as f:
            sample_img_base64_data = f.read()

        # Handle image info:
        img_info = []
        images_list = item.get("image", None)
        if images_list is not None:
            for img in images_list:
                img_obj = {
                    "image_name": img,
                    "raw_url": "https://example.com/{}".format(img),
                    "matched_text_index": matched_text_index,
                    "matched_sim": 0.75,
                    "image_base64": sample_img_base64_data,
                }
                img_info.append(img_obj)

        # Create similarity_matrix: one row per image, one column per text,
        # with a 1 at the matched text index.
        similarity_matrix = []
        for _ in img_info:
            row = [0] * len(text_list)
            row[matched_text_index] = 1
            similarity_matrix.append(row)

        output_item = {
            "id": item.get("id", None),
            "url": "https://example.com",
            "text_list": item.get("text_list", None),
            "image_info": img_info,
            "similarity_matrix": similarity_matrix,
            "could_have_url_duplicate": 0
        }

        # Save the item as a separate JSON file
        output_path = os.path.join(output_folder, f"{idx:05d}.json")
        with open(output_path, 'w') as out_f:
            json.dump(output_item, out_f)
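A minimal usage sketch of the two helpers above, reusing the imports and functions from that script and assuming hypothetical paths for the ADNI annotation file and output folder (adjust to your layout):

# Hypothetical paths, for illustration only.
convert_adni_to_mmc4("VLM_ADNI_DATA/adni_annotations.json", "VLM_ADNI_DATA/mmc4_json")
compress_directory_to_tar("VLM_ADNI_DATA/mmc4_json")

# Sanity check: read one record back out of the first shard.
with tarfile.open("replicate_mmc4/000000000.tar") as tar:
    member = tar.getmembers()[0]
    record = json.load(tar.extractfile(member))
    print(record["id"], len(record["image_info"]))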
ChatGPT-generated sequences
A subset of our models (listed below) were also trained on experimental ChatGPT-generated (image, text) sequences, where the images are pulled from LAION. The shards containing these sequences can be found at this CodaLab worksheet. The authors are unable to distribute the raw images in the released shards; the images must be pre-downloaded from the URLs in the JSON files and converted to base64 before this data can be used for training in our codebase.
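As a rough illustration of that pre-download step, the sketch below fetches each image referenced in an MMC4-style JSON record and stores it as base64. This is a hedged example rather than the project's official tooling; the field names follow the record layout shown earlier, and error handling is kept minimal.

import base64
import json
import requests

def add_base64_images(record_path):
    # Load one MMC4-style record, fetch each image URL, and embed it as base64.
    with open(record_path) as f:
        record = json.load(f)
    for img in record.get("image_info", []):
        resp = requests.get(img["raw_url"], timeout=10)
        resp.raise_for_status()
        img["image_base64"] = base64.b64encode(resp.content).decode("utf-8")
    with open(record_path, "w") as f:
        json.dump(record, f)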
Models trained with ChatGPT-generated sequences:
Training Command
A sample Slurm training script is provided in scripts/. You can also modify the following command (which was specifically used in our case):
torchrun --nnodes=1 --nproc_per_node=8 open_flamingo/train/train.py \
  --lm_path anas-awadalla/mpt-1b-redpajama-200b \
  --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
  --cross_attn_every_n_layers 1 \
  --dataset_resampled \
  --batch_size_mmc4 2 \
  --train_num_samples_mmc4 1000 \
  --workers=4 \
  --run_name OpenFlamingo-3B-vitl-mpt1b \
  --num_epochs 20 \
  --warmup_steps 1875 \
  --mmc4_textsim_threshold 0.24 \
  --mmc4_shards "modifications/VLM_ADNI_DATA/replicate_mmc4/{000000000..000000040}.tar" \
  --report_to_wandb
The MPT-1B base and instruct modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase. We suggest using a modified version of the MPT-1B models found here and here.
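For context, "accepting the labels kwarg" means the language model computes its own next-token cross-entropy loss inside forward(). The hedged sketch below shows the equivalent computation done by hand, assuming the model returns an object with a logits tensor; the variable names are illustrative, and the actual codebase relies on the (modified) model doing this internally.

import torch.nn.functional as F

# vision_x, input_ids, and attention_mask are assumed to come from the data loader.
outputs = model(vision_x=vision_x, lang_x=input_ids, attention_mask=attention_mask)
logits = outputs.logits[:, :-1, :]   # predict token t+1 from tokens up to t
targets = input_ids[:, 1:]
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    targets.reshape(-1),
    ignore_index=tokenizer.pad_token_id,
)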
Distributed Training
By default, train.py uses PyTorch's DistributedDataParallel for training.
To use FullyShardedDataParallel, use the --fsdp flag.
Some notes on FSDP from the OpenFlamingo team:
We recommend using the --fsdp_use_orig_params flag. If --fsdp is used without this flag, all language model embeddings will be unfrozen during training. (In contrast, the default behavior is to train only the newly added <image> and <|endofchunk|> tokens.)
Note: We’ve encountered issues using OPT with this flag. Other language models should be compatible.
Our current FSDP wrapping strategy does not permit training language model embeddings that use tied weights (i.e., tied input/output embeddings). To train such models with FSDP, the language model embeddings must be frozen with the --freeze_lm_embeddings flag.
We also implement gradient checkpointing and mixed precision training. Use the --gradient_checkpointing and --precision arguments, respectively.
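Putting these flags together, an FSDP launch might look like the following variant of the earlier command. The --precision value is an assumption (check train.py for the accepted choices), and --freeze_lm_embeddings is only needed for language models with tied input/output embeddings, as noted above.

torchrun --nnodes=1 --nproc_per_node=8 open_flamingo/train/train.py \
  ... same model and data arguments as above ... \
  --fsdp \
  --fsdp_use_orig_params \
  --freeze_lm_embeddings \
  --gradient_checkpointing \
  --precision amp_bf16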
Initializing an OpenFlamingo model
OpenFlamingo supports pretrained vision encoders from the OpenCLIP package, which includes OpenAI’s pretrained models.
OpenFlamingo also supports pretrained language models from the transformers package, such as MPT, RedPajama, LLaMA, OPT, GPT-Neo, GPT-J, and Pythia models.
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
    cache_dir="PATH/TO/CACHE/DIR"  # Defaults to ~/.cache
)
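After initialization, a quick smoke test can be run with a generation call in the style of the OpenFlamingo README. The image URL below is only an example; vision_x follows the (batch, num_images, frames, channels, height, width) convention, and the generation parameters are illustrative.

from PIL import Image
import requests

# Any RGB image works; this COCO image is just an example.
demo_image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

# Shape vision_x as (batch, num_images, frames, channels, height, width).
vision_x = image_processor(demo_image).unsqueeze(0)   # (1, C, H, W)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)         # (1, 1, 1, C, H, W)

tokenizer.padding_side = "left"  # pad on the left for generation
lang_x = tokenizer(["<image>An image of"], return_tensors="pt")

generated = model.generate(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=20,
    num_beams=3,
)
print("Generated text:", tokenizer.decode(generated[0]))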
Below are the results reported from our WandB (Weights & Biases) runs on the NVIDIA H100 GPUs.
The NVIDIA H100 system that was employed:
Model’s reported metrics