How to handle multi-modal data with PyTorch?

Artificial intelligence has ushered in a new era of voluminous and varied data. Multi-modal data, which encompasses different forms of data like text, images, audio, and video, is becoming increasingly common in AI applications. Handling such diverse data types is a challenge that requires sophisticated approaches to integrate and process information effectively.

This article explores multi-modal data handling using PyTorch, a leading deep-learning library. We will cover the concepts, techniques, and tools available within PyTorch and its ecosystem, focusing on the TorchMultimodal library and the WideDeep PyTorch framework (the pytorch-widedeep package).

The goal is to provide a comprehensive guide on building and fine-tuning models that can efficiently work with multi-modal data. This includes understanding the fundamental concepts, preparing datasets, implementing models, and optimizing performance. Whether you're a researcher, a practitioner, or someone with a keen interest in multi-modal AI, this guide aims to equip you with the knowledge to push the boundaries of what's possible with AI.

Understanding Multi-Modal Data in PyTorch

Definition of Multi-Modal Data:

Multi-modal data refers to data that comes from different sources or formats. For instance, a single post on a social media platform may contain text, images, and possibly audio or video. Integrating these diverse data types to improve the accuracy of predictive models is at the core of multi-modal data analysis.

The Role of PyTorch in Multi-Modal Data Handling:

PyTorch, an open-source machine learning library, has become a go-to tool for developing deep learning models due to its flexibility and ease of use. With its dynamic computation graph and strong GPU acceleration, PyTorch facilitates building and training complex models that can learn from multi-modal data.

PyTorch offers a comprehensive ecosystem to handle multi-modal data, including libraries like TorchVision and TorchText, which provide datasets and models for vision and text-related tasks. The introduction of TorchMultimodal extends this ecosystem by focusing on the intersection of different modalities, enabling models that can process and relate information from text, images, videos, and other sensory inputs in a unified framework.

Key Concepts of Multi-Modal Learning

Introduction to TorchMultimodal:

TorchMultimodal is a domain library within the PyTorch ecosystem designed to streamline the development of state-of-the-art (SoTA) multi-task, multi-modal models. It serves as a modular framework providing pre-built blocks like fusion layers, datasets, and utility functions that facilitate the integration of multi-modal data sources.

Wide and Deep Learning: Combining Structured and Unstructured Data:

The 'Wide and Deep' learning paradigm is a key approach in multi-modal learning. It combines 'wide' linear models that handle structured, tabular data with 'deep' neural networks that process unstructured data such as text and images. The wide part memorizes specific feature interactions, while the deep part generalizes to unseen feature combinations through learned representations.
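To make the idea concrete, here is a minimal sketch of a wide-and-deep model written in plain PyTorch. The class name `TinyWideAndDeep`, the layer sizes, and the random inputs are all hypothetical and chosen only for illustration; this is not the pytorch-widedeep implementation.

```python
import torch
from torch import nn

class TinyWideAndDeep(nn.Module):
    """Hypothetical minimal wide-and-deep model producing one logit per sample."""

    def __init__(self, n_wide_features, n_categories, embed_dim=8, hidden_dim=16):
        super().__init__()
        # Wide part: a single linear layer over sparse (e.g. one-hot) features
        self.wide = nn.Linear(n_wide_features, 1)
        # Deep part: an embedding of a categorical id followed by a small MLP
        self.embedding = nn.Embedding(n_categories, embed_dim)
        self.deep = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x_wide, x_cat):
        # The two logits are summed, so both parts are trained jointly
        return self.wide(x_wide) + self.deep(self.embedding(x_cat))

torch.manual_seed(0)
model = TinyWideAndDeep(n_wide_features=10, n_categories=5)
x_wide = torch.randint(0, 2, (4, 10)).float()  # 4 samples, 10 binary wide features
x_cat = torch.randint(0, 5, (4,))              # 4 categorical ids in [0, 5)
logits = model(x_wide, x_cat)
print(logits.shape)  # one logit per sample
```

Summing the two logits before the loss is what lets the linear part and the neural part be optimized together rather than as two separate models.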

Pretrained Models in Multi-Modal Learning:

TorchMultimodal offers access to pre-trained models such as FLAVA, MDETR, and Omnivore. These models are trained on large datasets and are capable of understanding and processing multiple types of input data, making them highly versatile for a variety of tasks:

  • FLAVA: A model that utilizes transformer-based encoders for text and image inputs, pre-trained using a combination of contrastive learning and masking strategies.

  • MDETR: A multi-modal transformer that excels at object detection by correlating regions in an image with textual queries.

  • Omnivore: A model that showcases the ability to handle various data types, including images, videos, and 3D data, through a single architecture.

These models provide an excellent starting point for custom tasks, allowing researchers and developers to fine-tune them on domain-specific datasets, saving valuable time and computational resources.

Preparing Your Data for Multi-Modal Learning

Data Preprocessing: Image and Text:

Proper data preprocessing is essential for multi-modal learning. For image data, common preprocessing steps include resizing, normalization, and augmentation, which can be implemented using TorchVision's transforms. Text data, on the other hand, requires tokenization, padding, and possibly embedding, which can be accomplished with libraries like TorchText or Hugging Face's Transformers.
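The sketch below illustrates these steps with plain tensors so that it runs without downloading anything. The random image, the tiny vocabulary, and the `encode` helper are hypothetical stand-ins; in practice TorchVision transforms and a real tokenizer would take their place. The mean and standard deviation are the commonly used ImageNet statistics, which is an assumption about the image encoder you plan to use.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# --- Image side: scale to [0, 1], then normalize each channel ---
torch.manual_seed(0)
raw_image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)  # fake RGB image
image = raw_image.float() / 255.0
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)  # ImageNet statistics (assumption)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
image = (image - mean) / std

# --- Text side: toy whitespace tokenizer with a hypothetical vocabulary ---
vocab = {"[PAD]": 0, "[UNK]": 1, "a": 2, "dog": 3, "on": 4, "the": 5, "beach": 6}

def encode(sentence):
    # Map each lower-cased word to its id, falling back to [UNK]
    return torch.tensor([vocab.get(tok, vocab["[UNK]"]) for tok in sentence.lower().split()])

sentences = ["a dog on the beach", "the dog"]
# Pad every sequence to the length of the longest one in the batch
token_ids = pad_sequence([encode(s) for s in sentences], batch_first=True, padding_value=vocab["[PAD]"])
# The attention mask marks real tokens with 1 and padding with 0
attention_mask = (token_ids != vocab["[PAD]"]).long()

print(image.shape)
print(token_ids)
print(attention_mask)
```

The attention mask matters downstream: it tells a transformer encoder which positions are padding and should be ignored.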

Building a Dataset for Multi-Modal Learning:

Creating a dataset for multi-modal learning involves defining a structure that encapsulates the different data types. For example, a dataset class in PyTorch should be able to return a single item as a dictionary with keys corresponding to the different modalities, such as:

    'text': tensor representing tokenized text,
    'image': tensor representing transformed image,
    'label': tensor representing the target output

To implement such a dataset, one could extend the torch.utils.data.Dataset class and override the __getitem__ method to include the processing of each modality.

Here is a sample snippet for a custom multi-modal dataset class:

import torch
from torch.utils.data import Dataset
from torchvision import transforms as T
from transformers import BertTokenizer

class MyMultimodalDataset(Dataset):
    def __init__(self, images, texts, labels):
        self.images = images  # sequence of PIL images
        self.texts = texts    # sequence of strings
        self.labels = labels  # sequence of integer class ids
        self.image_transform = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.image_transform(self.images[idx])
        # truncation=True keeps long texts within the 512-token limit
        text = self.tokenizer(self.texts[idx], padding='max_length', truncation=True, max_length=512, return_tensors='pt').input_ids.squeeze(0)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return {'image': image, 'text': text, 'label': label}

This code snippet transforms images into tensors suitable for a CNN, while texts are tokenized into a format ideal for a transformer model.
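A dataset that returns dictionaries works directly with PyTorch's DataLoader, whose default collation stacks the values under each key into a batch. The following sketch shows this with a hypothetical `ToyMultimodalDataset` that returns random tensors in the same dictionary layout, so it runs without any image files or tokenizer downloads.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyMultimodalDataset(Dataset):
    """Hypothetical stand-in that returns random tensors in a dictionary layout."""

    def __init__(self, n_samples=10, seq_len=16):
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(n_samples, 3, 32, 32, generator=generator)
        self.texts = torch.randint(0, 100, (n_samples, seq_len), generator=generator)
        self.labels = torch.randint(0, 2, (n_samples,), generator=generator)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {"image": self.images[idx], "text": self.texts[idx], "label": self.labels[idx]}

loader = DataLoader(ToyMultimodalDataset(), batch_size=4, shuffle=False)
batch = next(iter(loader))
# Each key now holds a stacked tensor with a leading batch dimension
for key, value in batch.items():
    print(key, tuple(value.shape))
```

Because every sample has the same shape per key, no custom `collate_fn` is needed here; variable-length text without padding would require one.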

Implementing Multi-Modal Models in PyTorch

Overview of PyTorch Model Components:

When building multi-modal models in PyTorch, you generally deal with several sub-models, each designed to handle a specific data type. For example, you might use a Convolutional Neural Network (CNN) for images and a Transformer for text. The outputs of these sub-models are then combined, often in a fusion layer, which can be as simple as concatenation followed by fully connected layers or as complex as a cross-attention mechanism.

Defining the Model Architecture:

The architecture of a multi-modal model in PyTorch typically includes:

  • Unimodal Encoders: Separate encoders for each modality (e.g., ResNet for images, BERT for text).

  • Fusion Layer: A component that combines the outputs of unimodal encoders. This can involve element-wise addition, multiplication, concatenation, or even a learnable mixer.

  • Classifier/Regressor Head: The final part of the model, which makes predictions based on fused representations.
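The fusion options listed above can be sketched on random feature tensors as follows. The dimensions, the number of text tokens and image patches, and the choice of text as the query side of the cross-attention are illustrative assumptions rather than a recommendation.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, dim = 4, 32
image_features = torch.randn(batch_size, dim)  # e.g. pooled output of an image encoder
text_features = torch.randn(batch_size, dim)   # e.g. pooled output of a text encoder

# 1. Concatenation: keeps both vectors intact and doubles the feature size
concat = torch.cat([image_features, text_features], dim=1)

# 2. Element-wise addition and multiplication: require equal feature sizes
added = image_features + text_features
multiplied = image_features * text_features

# 3. Cross-attention: text tokens (queries) attend to image patches (keys and values)
text_tokens = torch.randn(batch_size, 10, dim)    # 10 token embeddings per sample
image_patches = torch.randn(batch_size, 49, dim)  # 49 patch embeddings per sample
cross_attention = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True)
attended, attention_weights = cross_attention(query=text_tokens, key=image_patches, value=image_patches)

print(concat.shape, added.shape, multiplied.shape)
print(attended.shape, attention_weights.shape)
```

Concatenation is the usual baseline because it imposes no constraint on feature sizes, while cross-attention works on token-level features and lets one modality select the relevant parts of the other.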

Here's a conceptual snippet for a simple multi-modal model architecture in PyTorch:

import torch
from torch import nn
from torchvision.models import resnet50
from transformers import BertModel

class SimpleMultimodalModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()  # must run before submodules are assigned
        # Newer torchvision versions prefer the `weights=` argument over `pretrained=True`
        self.image_encoder = resnet50(pretrained=True)
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.fusion_layer = nn.Linear(self.image_encoder.fc.out_features + self.text_encoder.config.hidden_size, 512)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, images, input_ids, attention_mask):
        # The ResNet classification head is kept, so image features are its 1000 outputs
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        # Fuse by concatenating along the feature dimension
        combined = torch.cat([image_features, text_features], dim=1)
        fused = torch.relu(self.fusion_layer(combined))
        logits = self.classifier(fused)
        return logits

Model Training and Fine-Tuning: A Step-by-Step Guide:

Training a multi-modal model involves:

  1. Loading the datasets.

  2. Defining the model, loss function, and optimizer.

  3. Running the training loop, which includes forward and backward passes.

  4. Evaluating the model on a validation set.

  5. Fine-tuning, if using pre-trained components, to adapt the model to specific tasks.

An example training loop snippet could be:

# Assuming the existence of DataLoader instances for training and validation: train_loader, val_loader
# Each batch is assumed to also carry an 'attention_mask' entry produced by the tokenizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        images, input_ids, attention_mask, labels = batch['image'], batch['text'], batch['attention_mask'], batch['label']
        optimizer.zero_grad()  # clear gradients from the previous step
        outputs = model(images, input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()   # backward pass
        optimizer.step()  # parameter update

    # Validation step
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for batch in val_loader:
            images, input_ids, attention_mask, labels = batch['image'], batch['text'], batch['attention_mask'], batch['label']
            outputs = model(images, input_ids, attention_mask)
            val_loss += criterion(outputs, labels).item()
        val_loss /= len(val_loader)
    print(f"Epoch {epoch}, Validation Loss: {val_loss}")
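For the fine-tuning step in the list above, a common starting point is to freeze the pre-trained encoder and train only the new head. The sketch below shows this on a hypothetical `ToyFineTuneModel` whose small linear encoder merely stands in for a pre-trained network; the sizes and learning rates are arbitrary.

```python
import torch
from torch import nn

class ToyFineTuneModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for a pre-trained encoder (hypothetical, randomly initialized here)
        self.encoder = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.head = nn.Linear(32, 3)  # new task-specific head

    def forward(self, x):
        return self.head(self.encoder(x))

torch.manual_seed(0)
model = ToyFineTuneModel()

# Freeze the encoder so only the head receives gradient updates
for param in model.encoder.parameters():
    param.requires_grad = False
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3)

# Alternative (not run here): keep everything trainable but give the encoder a smaller learning rate
# optimizer = torch.optim.AdamW([
#     {"params": model.encoder.parameters(), "lr": 1e-5},
#     {"params": model.head.parameters(), "lr": 1e-3},
# ])

encoder_before = model.encoder[0].weight.clone()
head_before = model.head.weight.clone()

# One training step on random data
x, y = torch.randn(8, 16), torch.randint(0, 3, (8,))
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

encoder_unchanged = torch.equal(encoder_before, model.encoder[0].weight)
head_changed = not torch.equal(head_before, model.head.weight)
print(encoder_unchanged, head_changed)
```

Freezing is cheap and reduces the risk of overwriting pre-trained features on small datasets; unfreezing with a lower encoder learning rate is a frequent next step once the head has stabilized.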

Fine-Tuning with FLAVA: A Case Study

Setting up the FLAVA Model for Fine-Tuning:

Fine-tuning a pre-trained model involves slightly adjusting its parameters to adapt to a new, often related task. With FLAVA, we're leveraging its pre-trained capabilities and adapting it for a visual question-answering (VQA) task.

Here's how to set up FLAVA for fine-tuning:

from torchmultimodal.models.flava.model import flava_model_for_classification

# Assume num_classes is the number of possible answers in the VQA task
flava_model = flava_model_for_classification(num_classes=num_classes)

Data Loading and Transformation for VQA:

The data for VQA consists of images and their corresponding questions. The FLAVA model requires these inputs to be transformed into a specific format:

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Build the image pipeline once instead of on every call
image_transform = Compose([Resize(256), CenterCrop(224), ToTensor()])

def transform(sample):
    image = image_transform(sample['image'])
    inputs_text = tokenizer(sample['question'], padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    inputs_text = {key: val.squeeze(0) for key, val in inputs_text.items()}  # Remove batch dimension
    return {'image': image, 'input_ids': inputs_text['input_ids'], 'attention_mask': inputs_text['attention_mask']}

Training the FLAVA Model for VQA:

Training involves passing the transformed data through the model and optimizing the parameters using a loss function.

optimizer = torch.optim.AdamW(flava_model.parameters(), lr=5e-5)

for epoch in range(num_epochs):
    for idx, batch in enumerate(train_dataloader):
        images, questions, answers = batch['image'], batch['input_ids'], batch['answers']
        optimizer.zero_grad()
        # Keyword names follow the TorchMultimodal FLAVA fine-tuning tutorial; verify them against your installed version.
        # When labels are passed, the classification model computes the cross-entropy loss itself.
        outputs = flava_model(image=images, text=questions, labels=answers)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        if idx % log_interval == 0:
            print(f"Epoch {epoch} [{idx}/{len(train_dataloader)}], Loss: {loss.item()}")

Code Snippet for FLAVA Fine-Tuning:

The provided code snippets are part of the setup and training process for fine-tuning the FLAVA model on a VQA task. The exact implementation may vary depending on the specifics of the task and dataset.

Expected Output Analysis:

When training, you should expect the loss to gradually decrease as the model learns to answer questions based on the images and associated text. Plotting the training loss over time can give you insights into the learning process and whether the model is improving.

[Figure: Training Loss over Epochs]

The figure shows a hypothetical training loss over epochs. As expected, the loss decreases, indicating that the model is learning from the data over time. Such a plot is instrumental in monitoring the training process, identifying patterns, and making informed decisions about adjustments to the training regimen or hyperparameters.
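Collecting the numbers for such a plot only requires appending the loss to a list during training. The sketch below does this on a hypothetical toy regression problem that stands in for a real multi-modal task; the data, model, and learning rate are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical toy regression problem standing in for a real multi-modal task
x = torch.randn(64, 4)
true_weights = torch.tensor([[1.0], [-2.0], [0.5], [3.0]])
y = x @ true_weights

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
criterion = nn.MSELoss()

loss_history = []  # one entry per epoch
for epoch in range(20):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())

print(f"first epoch: {loss_history[0]:.4f}, last epoch: {loss_history[-1]:.4f}")
```

The resulting `loss_history` list can be handed to any plotting library, for example matplotlib's `plt.plot`, to produce a loss curve. Recording a validation loss in a second list makes it easier to spot overfitting.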

Using WideDeep PyTorch Framework for Multi-Modal Data

Introduction to WideDeep:

The WideDeep PyTorch framework is designed to handle complex multi-modal data scenarios efficiently. It combines deep learning's capability to learn abstract representations (the "Deep" part) with linear models' ability to memorize sparse feature interactions (the "Wide" part). This combination is powerful for tasks involving structured tabular data and unstructured data like text and images.

Building Multi-Modal Models with WideDeep:

WideDeep facilitates building models that can process and integrate different data types seamlessly. It supports various deep learning components for text (deeptext) and images (deepimage) alongside tabular data (deeptabular).

Multi-Modal Models with WideDeep

A typical WideDeep model might be structured as follows:

  • Wide component: A linear model that uses one-hot encoded features to capture the relationships between categorical variables.

  • Deep component: Comprises multiple deep learning models tailored to specific data types, such as a CNN for image data and a transformer for text data.

Here's an example setup using WideDeep:

import numpy as np
from pytorch_widedeep.models import WideDeep
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, DeepImage, DeepText
from pytorch_widedeep.metrics import Accuracy

# Note: the class and argument names below follow an older pytorch-widedeep release;
# newer releases rename several of them, so check the documentation for your installed version.

# Assuming `df` is a DataFrame with your data, and that `wide_cols`, `cat_cols`,
# and `continuous_cols` are lists of column names you have defined
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols)
X_wide = wide_preprocessor.fit_transform(df)

tab_preprocessor = TabPreprocessor(embed_cols=cat_cols, continuous_cols=continuous_cols)
X_tab = tab_preprocessor.fit_transform(df)

# Model setup
wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
deeptabular = TabMlp(mlp_hidden_dims=[64, 32], column_idx=tab_preprocessor.column_idx, embed_input=tab_preprocessor.embeddings_input)
# Assume deepimage and deeptext are defined similarly

# Combining components into a WideDeep model
model = WideDeep(wide=wide, deeptabular=deeptabular, deepimage=deepimage, deeptext=deeptext)

Preprocessing data with WideDeep:

Preprocessing involves transforming raw data into a suitable format for each model component. WideDeep offers a comprehensive preprocessing module that simplifies this process for wide, deep tabular, text, and image data.

Training and Evaluating WideDeep Models:

WideDeep models are trained similarly to any PyTorch model, focusing on effectively handling the distinct data types. The framework provides a Trainer class that abstracts the training loop, making it easier to train and evaluate your models.

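from pytorch_widedeep import Trainer

trainer = Trainer(model=model, objective="binary", metrics=[Accuracy])
# Pass X_img and X_text as well when the model includes deepimage and deeptext components
trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10, batch_size=256, val_split=0.1)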

Explaining Model Predictions with WideDeep:

Understanding why your model makes specific predictions is crucial. WideDeep allows for integrating explainability tools like SHAP or Captum to interpret the model's decisions, notably how different features influence the model's predictions.

Advanced Techniques for Scaling and Optimization

Efficiently scaling and optimizing multi-modal models is crucial for handling the complexity and computational demands of large datasets. PyTorch offers several advanced techniques for this purpose:

Scaling Models with PyTorch Distributed:

PyTorch's Distributed package (`torch.distributed`) enables parallelism across multiple processes and machines, facilitating the scaling of model training to larger datasets and faster execution times. This is particularly useful for multi-modal models, which can be computationally intensive.

Key components include:

  • Data Parallelism: Distributes batches across multiple GPUs, allowing larger batch sizes and faster processing.

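  • Model Parallelism: Splits a single model across multiple GPUs, which is useful for large models that don't fit into a single GPU's memory.

  • Distributed Data Parallel (DDP): Runs one model replica per process (typically one per GPU) and synchronizes gradients across processes during the backward pass, which makes multi-GPU and multi-node data-parallel training efficient. It can be combined with model parallelism when each replica itself spans several GPUs.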
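The core idea behind data-parallel training can be checked in a single process: when each replica computes gradients on an equal-sized shard of the batch, the average of those gradients equals the gradient on the full batch. The sketch below simulates this with two copies of a small model. It deliberately does not use `torch.distributed`; it only illustrates the quantity that gradient synchronization produces, and the model and data are hypothetical.

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(8, 1)
x, y = torch.randn(16, 8), torch.randn(16, 1)
criterion = nn.MSELoss()  # mean reduction, so equal-sized shards average cleanly

# Reference: gradient computed on the full batch
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Simulated data parallelism: two replicas, each sees half of the batch
shard_grads = []
for x_shard, y_shard in zip(x.chunk(2), y.chunk(2)):
    replica = copy.deepcopy(model)  # same weights as the original model
    replica.zero_grad()
    criterion(replica(x_shard), y_shard).backward()
    shard_grads.append(replica.weight.grad)

# Averaging the per-replica gradients mimics the result of gradient synchronization
averaged_grad = torch.stack(shard_grads).mean(dim=0)
grads_match = torch.allclose(full_grad, averaged_grad, atol=1e-5)
print(grads_match)
```

In a real multi-GPU job, each process would wrap its model in `torch.nn.parallel.DistributedDataParallel` after initializing a process group, and the averaging would happen automatically during `backward()`.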

Using Activation Checkpointing:

Activation checkpointing saves memory by trading compute for memory. It works by recomputing activations on the fly during the backward pass rather than storing them during the forward pass. This is beneficial for training deep and memory-intensive models.

PyTorch's torch.utils.checkpoint module can be easily integrated into your models to utilize this technique. Here's a basic example:

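from torch import nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self, submodules):
        super().__init__()
        # nn.ModuleList registers the submodules so their parameters are tracked
        self.submodules = nn.ModuleList(submodules)

    def forward(self, x):
        # Use checkpointing for each submodule
        for submodule in self.submodules:
            # use_reentrant=False is the variant recommended in recent PyTorch releases
            x = checkpoint(submodule, x, use_reentrant=False)
        return x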
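Checkpointing changes memory usage, not the result. The following sketch runs the same small stack of layers with and without checkpointing and compares outputs and gradients. The layer sizes are arbitrary, and the `run` helper is introduced only for this comparison.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
blocks = nn.ModuleList([nn.Sequential(nn.Linear(16, 16), nn.ReLU()) for _ in range(3)])
x = torch.randn(4, 16)

def run(use_checkpoint):
    # Reset gradients so the two runs do not accumulate into each other
    for p in blocks.parameters():
        p.grad = None
    out = x
    for block in blocks:
        if use_checkpoint:
            # Activations inside `block` are recomputed during the backward pass
            out = checkpoint(block, out, use_reentrant=False)
        else:
            out = block(out)
    out.sum().backward()
    return out.detach().clone(), blocks[0][0].weight.grad.clone()

plain_out, plain_grad = run(use_checkpoint=False)
ckpt_out, ckpt_grad = run(use_checkpoint=True)

outputs_match = torch.allclose(plain_out, ckpt_out)
grads_match = torch.allclose(plain_grad, ckpt_grad)
print(outputs_match, grads_match)
```

The memory saving only becomes noticeable with large activations, and it comes at the cost of roughly one extra forward pass through each checkpointed block.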

Mixed Precision Training:

Mixed precision training utilizes 16-bit and 32-bit floating-point types to reduce memory usage and improve performance. PyTorch's Automatic Mixed Precision (AMP) package makes it straightforward to apply this optimization.

Example usage:

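from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

# Assumes model, optimizer, loss_fn, and a dataloader yielding (data, target) pairs on a CUDA device
for data, target in dataloader:
    optimizer.zero_grad()
    with autocast():
        output = model(data)
        loss = loss_fn(output, target)
    # Scale the loss to avoid fp16 gradient underflow, then step and update the scale factor
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()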

Efficient Data Loading:

Optimizing the data pipeline is crucial for scaling multi-modal models. PyTorch's DataLoader allows asynchronous data loading and preprocessing using multiple workers, significantly reducing the time models wait for data.

These techniques, when combined with the foundational practices discussed earlier, equip you with the tools needed to develop, scale, and optimize multi-modal models in PyTorch effectively.


This article has explored the nuances of handling multi-modal data in PyTorch, from leveraging TorchMultimodal and WideDeep frameworks to advanced techniques for scaling and optimization. As multi-modal AI evolves, mastering these concepts and tools will be invaluable for developing sophisticated models to interpret and integrate diverse data types.


As multi-modal AI advances, staying updated with the latest research and tools will be crucial for pushing the boundaries of what's possible.
