DINOv2 for Custom Dataset Segmentation: A Comprehensive Tutorial

Discover DINOv2, a powerful self-supervised vision transformer trained on 142M images. This tutorial guides you through data loading, preprocessing, model definition, and training using PyTorch.

Figure 1: The model working end to end: output image after training for just one epoch (image by author)

After YOLOv8 and SAM (Segment Anything Model), the most anticipated computer vision model is DINOv2. I got the motivation for this tutorial from this GitHub repository: https://github.com/NielsRogge/Transformers-Tutorials/tree/master. While running the code, I found two bugs that caused some annoying errors while training the model (in the original tutorial, training is stopped after a number of steps, and errors arise both mid-training and at the last training step). The entire code is taken from that notebook (except for some changes :) ), and here is the plan of attack:


Plan of Attack

  1. Introduction to DINOv2
  2. Library installation
  3. Load dataset
  4. Create PyTorch dataset
  5. Create PyTorch dataloaders
  6. Define model
  7. Train the model

Introduction to DINOv2

DINOv2 is a vision transformer that has been trained in a self-supervised manner on a meticulously curated dataset of 142 million images. It produces strong, general-purpose image features, or embeddings, for downstream tasks such as image classification, image segmentation, and depth estimation.

Figure 1 conceptualizes this approach: in this tutorial, I simply train a linear transformation (a 1x1 convolution layer) on top of a frozen DINOv2 backbone. This transformation maps the features (patch embeddings) to logits (the unnormalized scores output by the network, indicative of the model's predictions). In the context of semantic segmentation, the logits take the shape (batch_size, num_classes, height, width), corresponding to a predicted score for each class at each pixel.
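To make this concrete, here is a minimal sketch (with dummy tensors; the sizes match what we will use later: hidden size 768, a 32x32 grid of patches, and 104 FoodSeg103 classes) showing how a 1x1 Conv2d acts as a per-position linear map from patch embeddings to class logits:

import torch

# dummy patch embeddings in channels-first layout:
# (batch_size, hidden_size, patches_high, patches_wide)
embeddings = torch.randn(2, 768, 32, 32)

# a 1x1 convolution applies the same linear map at every spatial
# position, turning each patch embedding into per-class logits
head = torch.nn.Conv2d(in_channels=768, out_channels=104, kernel_size=1)

logits = head(embeddings)
print(logits.shape)  # torch.Size([2, 104, 32, 32])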

Library installation

Here are the main libraries we need:

!pip install -q git+https://github.com/huggingface/transformers.git datasets 
 
!pip install -q evaluate

Load dataset

Next, let’s load an image segmentation dataset. In this case, we’ll use the FoodSeg103 dataset.

from datasets import load_dataset 
 
#dataset 
dataset = load_dataset("EduardoPacheco/FoodSeg103") 
 
# labels 
id2label = { 
    0: "background", 
    1: "candy", 
    2: "egg tart", 
    3: "french fries", 
    4: "chocolate", 
    5: "biscuit", 
    6: "popcorn", 
    7: "pudding", 
    8: "ice cream", 
    9: "cheese butter", 
    10: "cake", 
    11: "wine", 
    12: "milkshake", 
    13: "coffee", 
    14: "juice", 
    15: "milk", 
    16: "tea", 
    17: "almond", 
    18: "red beans", 
    19: "cashew", 
    20: "dried cranberries", 
    21: "soy", 
    22: "walnut", 
    23: "peanut", 
    24: "egg", 
    25: "apple", 
    26: "date", 
    27: "apricot", 
    28: "avocado", 
    29: "banana", 
    30: "strawberry", 
    31: "cherry", 
    32: "blueberry", 
    33: "raspberry", 
    34: "mango", 
    35: "olives", 
    36: "peach", 
    37: "lemon", 
    38: "pear", 
    39: "fig", 
    40: "pineapple", 
    41: "grape", 
    42: "kiwi", 
    43: "melon", 
    44: "orange", 
    45: "watermelon", 
    46: "steak", 
    47: "pork", 
    48: "chicken duck", 
    49: "sausage", 
    50: "fried meat", 
    51: "lamb", 
    52: "sauce", 
    53: "crab", 
    54: "fish", 
    55: "shellfish", 
    56: "shrimp", 
    57: "soup", 
    58: "bread", 
    59: "corn", 
    60: "hamburg", 
    61: "pizza", 
    62: "hanamaki baozi", 
    63: "wonton dumplings", 
    64: "pasta", 
    65: "noodles", 
    66: "rice", 
    67: "pie", 
    68: "tofu", 
    69: "eggplant", 
    70: "potato", 
    71: "garlic", 
    72: "cauliflower", 
    73: "tomato", 
    74: "kelp", 
    75: "seaweed", 
    76: "spring onion", 
    77: "rape", 
    78: "ginger", 
    79: "okra", 
    80: "lettuce", 
    81: "pumpkin", 
    82: "cucumber", 
    83: "white radish", 
    84: "carrot", 
    85: "asparagus", 
    86: "bamboo shoots", 
    87: "broccoli", 
    88: "celery stick", 
    89: "cilantro mint", 
    90: "snow peas", 
    91: "cabbage", 
    92: "bean sprouts", 
    93: "onion", 
    94: "pepper", 
    95: "green beans", 
    96: "French beans", 
    97: "king oyster mushroom", 
    98: "shiitake", 
    99: "enoki mushroom", 
    100: "oyster mushroom", 
    101: "white button mushroom", 
    102: "salad", 
    103: "other ingredients" 
} 
 
# visualize the images and masks 
import numpy as np 
import matplotlib.pyplot as plt 
 
# map every class to a random color 
id2color = {k: list(np.random.choice(range(256), size=3)) for k in id2label} 
 
def visualize_map(image, segmentation_map): 
    color_seg = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3 
    for label, color in id2color.items(): 
        color_seg[segmentation_map == label, :] = color 
 
    # Show image + mask 
    img = np.array(image) * 0.5 + color_seg * 0.5 
    img = img.astype(np.uint8) 
 
    plt.figure(figsize=(15, 10)) 
    plt.imshow(img) 
    plt.show() 
 
# pick an example from the training set to visualize 
example = dataset["train"][0] 
image = example["image"] 
segmentation_map = np.array(example["label"]) 
 
visualize_map(image, segmentation_map)

Create PyTorch dataset

To prepare examples for the model, we create a standard PyTorch dataset that includes image augmentations. We resize the training images to a uniform resolution of 448x448 pixels, randomly flip them horizontally, and normalize the color channels, ensuring all training images share the same fixed resolution. For this process, we employ the Albumentations library, although it’s worth noting that other libraries, such as Torchvision or Kornia, can also serve this purpose.

It’s important to remember that the model expects input pixel_values with the shape (batch_size, num_channels, height, width). Since Albumentations operates on NumPy arrays, which use a channels-last format, we need to reorder the dimensions to place the channels first. In addition, the model requires labels in the shape (batch_size, height, width), which provide the ground-truth label for each pixel in every example of the batch.

from torch.utils.data import Dataset 
import torch 
 
class SegmentationDataset(Dataset): 
  def __init__(self, dataset, transform): 
    self.dataset = dataset 
    self.transform = transform 
 
  def __len__(self): 
    return len(self.dataset) 
 
  def __getitem__(self, idx): 
    item = self.dataset[idx] 
    original_image = np.array(item["image"]) 
    original_segmentation_map = np.array(item["label"]) 
 
    transformed = self.transform(image=original_image, mask=original_segmentation_map) 
    image, target = torch.tensor(transformed['image']), torch.LongTensor(transformed['mask']) 
 
    # convert to C, H, W 
    image = image.permute(2,0,1) 
 
    return image, target, original_image, original_segmentation_map 
 
 
# Let's create the training and validation datasets (note that augmentation, the random horizontal flip, is applied only to training images). 
 
import albumentations as A 
 
ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255 
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255 
 
train_transform = A.Compose([ 
    # had an issue with an image being too small to crop, PadIfNeeded didn't help... 
    # if anyone knows why this is happening I'm happy to read why 
    # A.PadIfNeeded(min_height=448, min_width=448), 
    # A.RandomResizedCrop(height=448, width=448), 
    A.Resize(width=448, height=448), 
    A.HorizontalFlip(p=0.5), 
    A.Normalize(mean=ADE_MEAN, std=ADE_STD), 
], is_check_shapes=False) 
 
val_transform = A.Compose([ 
    A.Resize(width=448, height=448), 
    A.Normalize(mean=ADE_MEAN, std=ADE_STD), 
], is_check_shapes=False) 
 
train_dataset = SegmentationDataset(dataset["train"], transform=train_transform) 
val_dataset = SegmentationDataset(dataset["validation"], transform=val_transform) 
 
pixel_values, target, original_image, original_segmentation_map = train_dataset[3] 
print(pixel_values.shape) 
print(target.shape)
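With the transforms above, this should print torch.Size([3, 448, 448]) for the pixel values (channels first) and torch.Size([448, 448]) for the target.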

Create PyTorch dataloaders

Next, we create PyTorch dataloaders, which allow us to get batches of data (as neural networks are trained on batches using stochastic gradient descent or SGD). We just stack the various images and labels along a new batch dimension.

from torch.utils.data import DataLoader 
 
def collate_fn(inputs): 
    batch = dict() 
    batch["pixel_values"] = torch.stack([i[0] for i in inputs], dim=0) 
    batch["labels"] = torch.stack([i[1] for i in inputs], dim=0) 
    batch["original_images"] = [i[2] for i in inputs] 
    batch["original_segmentation_maps"] = [i[3] for i in inputs] 
 
    return batch 
 
train_dataloader = DataLoader(train_dataset, batch_size=3, shuffle=True, collate_fn=collate_fn) 
val_dataloader = DataLoader(val_dataset, batch_size=3, shuffle=False, collate_fn=collate_fn) 
 
batch = next(iter(train_dataloader)) 
for k,v in batch.items(): 
  if isinstance(v,torch.Tensor): 
    print(k,v.shape)
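With a batch size of 3, this should print pixel_values torch.Size([3, 3, 448, 448]) and labels torch.Size([3, 448, 448]).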

Define model

Next, we define the model, which comprises DINOv2 as the backbone along with a linear classifier on top. DINOv2 is a standard vision transformer, and thus it produces “patch embeddings”: an embedding vector for each image patch. Given that we use an image resolution of 448 pixels and a DINOv2 model with a patch size of 14, we obtain (448/14)² = 1024 patches. Consequently, after excluding the special CLS token, the backbone yields a tensor with the shape (batch_size, number of patches, hidden_size), or (batch_size, 1024, 768), for a batch of images (the base model has a hidden size, or embedding dimension, of 768).

Subsequently, we reshape this tensor to (batch_size, 32, 32, 768). Following this, we apply the linear layer (implemented here as a Conv2D layer, which acts as a linear transformation when using a kernel size of 1x1). This Conv2D layer transforms the patch embeddings into a logit tensor of shape (batch_size, num_labels, height, width), which is required for semantic segmentation, and we then bilinearly upsample it to the input resolution. This tensor contains the scores predicted by the model for every class, at each pixel, for every example in the batch.
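As a quick sanity check (a minimal sketch, assuming the facebook/dinov2-base checkpoint we use below), we can verify the backbone's output shape on a dummy 448x448 input:

import torch
from transformers import Dinov2Model

backbone = Dinov2Model.from_pretrained("facebook/dinov2-base")

dummy = torch.randn(1, 3, 448, 448)  # one 448x448 RGB image
with torch.no_grad():
    out = backbone(pixel_values=dummy)

# 1 CLS token + (448/14)^2 = 1024 patch tokens, each of dimension 768
print(out.last_hidden_state.shape)  # torch.Size([1, 1025, 768])

Now let's define the model: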

import torch 
from transformers import Dinov2Model, Dinov2PreTrainedModel 
from transformers.modeling_outputs import SemanticSegmenterOutput 
 
class LinearClassifier(torch.nn.Module): 
    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1): 
        super(LinearClassifier, self).__init__() 
 
        self.in_channels = in_channels 
        self.width = tokenW 
        self.height = tokenH 
        self.classifier = torch.nn.Conv2d(in_channels, num_labels, (1,1)) 
 
    def forward(self, embeddings): 
        embeddings = embeddings.reshape(-1, self.height, self.width, self.in_channels) 
        embeddings = embeddings.permute(0,3,1,2) 
 
        return self.classifier(embeddings) 
 
 
class Dinov2ForSemanticSegmentation(Dinov2PreTrainedModel): 
  def __init__(self, config): 
    super().__init__(config) 
 
    self.dinov2 = Dinov2Model(config) 
    self.classifier = LinearClassifier(config.hidden_size, 32, 32, config.num_labels) 
 
  def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, labels=None): 
    # use frozen features 
    outputs = self.dinov2(pixel_values, 
                            output_hidden_states=output_hidden_states, 
                            output_attentions=output_attentions) 
    # get the patch embeddings - so we exclude the CLS token 
    patch_embeddings = outputs.last_hidden_state[:,1:,:] 
 
    # convert to logits and upsample to the size of the pixel values 
    logits = self.classifier(patch_embeddings) 
    logits = torch.nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False) 
 
    loss = None 
    if labels is not None: 
      # important: we're going to use 0 here as ignore index instead of the default -100 
      # as we don't want the model to learn to predict background 
      loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0) 
      # note: no .squeeze() here; squeezing would drop the batch dimension when batch_size is 1 
      loss = loss_fct(logits, labels) 
 
    return SemanticSegmenterOutput( 
        loss=loss, 
        logits=logits, 
        hidden_states=outputs.hidden_states, 
        attentions=outputs.attentions, 
    ) 
 
#We can instantiate the model as follows: 
 
model = Dinov2ForSemanticSegmentation.from_pretrained("facebook/dinov2-base", id2label=id2label, num_labels=len(id2label)) 
 
#Important: we don't want to train the DINOv2 backbone, only the linear classification head. Hence we don't want to track any gradients for the backbone parameters. This will greatly save us in terms of memory used: 
 
for name, param in model.named_parameters(): 
  if name.startswith("dinov2"): 
    param.requires_grad = False 
 
#Let's perform a forward pass on a random batch to verify the shape of the logits and that we can calculate a loss: 
 
outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"]) 
print(outputs.logits.shape) 
print(outputs.loss) 
 
 
import evaluate 
metric = evaluate.load("mean_iou")

Train the model

Now let's train the model for one epoch:

from torch.optim import AdamW 
from tqdm.auto import tqdm 
 
# training hyperparameters 
# NOTE: I've just put some random ones here, not optimized at all 
# feel free to experiment, see also DINOv2 paper 
learning_rate = 5e-5 
epochs = 1 
 
optimizer = AdamW(model.parameters(), lr=learning_rate) 
 
# put model on GPU (set runtime to GPU in Google Colab) 
device = "cuda" if torch.cuda.is_available() else "cpu" 
model.to(device) 
 
# put model in training mode 
model.train() 
 
for epoch in range(epochs): 
  print("Epoch:", epoch) 
  for idx, batch in enumerate(tqdm(train_dataloader)): 
      pixel_values = batch["pixel_values"].to(device) 
      labels = batch["labels"].to(device) 
 
      # forward pass 
      outputs = model(pixel_values, labels=labels) 
      loss = outputs.loss 
 
      loss.backward() 
      optimizer.step() 
 
      # zero the parameter gradients 
      optimizer.zero_grad() 
 
      # evaluate 
      with torch.no_grad(): 
        predicted = outputs.logits.argmax(dim=1) 
 
        # note that the metric expects predictions + labels as numpy arrays 
        metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy()) 
 
      # let's print loss and metrics every 100 batches 
      if idx % 100 == 0: 
        metrics = metric.compute(num_labels=len(id2label), 
                                 ignore_index=0, 
                                 reduce_labels=False) 
 
        print("Loss:", loss.item()) 
        print("Mean_iou:", metrics["mean_iou"]) 
        print("Mean accuracy:", metrics["mean_accuracy"])

Results

Here are the results. I trained the model for only one epoch, and the mean IoU has already reached 0.37 (as shown in Result figure 1).

Result figure 1: IoU and loss metric plots (image by author)

And here are randomly selected results, shown in Result figure 2.

Result figure 2: Segmentation results on the test dataset after 1 epoch (image by author)
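To produce this kind of visualization yourself, here is a minimal inference sketch that reuses the visualize_map function from earlier (the example index is an arbitrary choice):

from PIL import Image

model.eval()

# take one validation example
pixel_values, target, original_image, original_segmentation_map = val_dataset[0]
pixel_values = pixel_values.unsqueeze(0).to(device)  # add a batch dimension

with torch.no_grad():
  outputs = model(pixel_values)

# per-pixel class prediction at the model's 448x448 resolution
predicted_map = outputs.logits.argmax(dim=1).squeeze(0).cpu().numpy()

# resize the prediction back to the original image size (nearest neighbour
# keeps the class ids intact), then overlay it on the image
predicted_map = Image.fromarray(predicted_map.astype(np.uint8)).resize(
    (original_image.shape[1], original_image.shape[0]), resample=Image.NEAREST)
visualize_map(original_image, np.array(predicted_map))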

Done 😃😃😃😃😃

Please feel free to take a look at the updated Colab notebook: https://colab.research.google.com/drive/1UMQj7F_x0fSy_gevlTZ9zLYn7b02kTqi?usp=sharing

References

DINOv2: Learning Robust Visual Features without Supervision: https://arxiv.org/abs/2304.07193

If you like this work, please share it with your friends, and like and follow me here on Medium.