---
title: Image Segmentation
sidebar_label: Image Segmentation
description: "Going beyond bounding boxes: How to classify every single pixel in an image."
tags:
  - deep-learning
  - cnn
  - computer-vision
  - segmentation
  - u-net
  - mask-rcnn
---
While Image Classification tells us what is in an image, and Object Detection tells us where it is, Image Segmentation provides a pixel-perfect understanding of the scene.

It is the process of partitioning a digital image into multiple segments (sets of pixels) to simplify or change the representation of an image into something that is more meaningful and easier to analyze.

## 1. Types of Segmentation

Not all segmentation tasks are the same. We generally categorize them into three levels of complexity:

### A. Semantic Segmentation

Every pixel is assigned a class label (e.g., "Road," "Sky," "Car"). However, it does not differentiate between multiple instances of the same class. Two cars parked next to each other will appear as a single connected "blob."

### B. Instance Segmentation

This goes a step further by detecting and delineating each distinct object of interest. If there are five people in a photo, instance segmentation will give each person a unique color/ID.

### C. Panoptic Segmentation

The "holy grail" of segmentation. It combines semantic and instance segmentation to provide a total understanding of the scene—identifying individual objects (cars, people) and background textures (sky, grass).
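The three output formats can be contrasted with plain arrays. The sketch below uses a hypothetical toy scene (two adjacent cars, class ID 1); the shapes and the `(class_id, instance_id)` panoptic encoding are illustrative conventions, not a fixed standard:

```python
import numpy as np

H, W = 4, 6  # tiny image for illustration

# Semantic segmentation: one class ID per pixel (0 = background, 1 = car).
# Two adjacent cars merge into a single "blob" of class 1.
semantic = np.zeros((H, W), dtype=np.int64)
semantic[1:3, 0:2] = 1  # car A
semantic[1:3, 2:4] = 1  # car B (indistinguishable from A here)

# Instance segmentation: one boolean mask per detected object.
instances = np.zeros((2, H, W), dtype=bool)
instances[0, 1:3, 0:2] = True  # car A
instances[1, 1:3, 2:4] = True  # car B

# Panoptic segmentation: every pixel gets (class_id, instance_id);
# background "stuff" keeps instance_id 0.
panoptic = np.stack([semantic, np.zeros_like(semantic)], axis=-1)
panoptic[instances[0], 1] = 1  # car A -> instance 1
panoptic[instances[1], 1] = 2  # car B -> instance 2

print(semantic.shape, instances.shape, panoptic.shape)
```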

## 2. The Architecture: Encoder-Decoder (U-Net)

Traditional CNNs lose spatial resolution through pooling. To get back to an image output of the same size as the input, we use an Encoder-Decoder architecture.

  1. Encoder (The "What"): A standard CNN that downsamples the image to extract high-level features.
  2. Bottleneck: The compressed representation of the image.
  3. Decoder (The "Where"): Uses Transposed Convolutions (Upsampling) to recover the spatial dimensions.
  4. Skip Connections: These are the "secret sauce" of the U-Net architecture. They pass high-resolution information from the encoder directly to the decoder to help refine the boundaries of the mask.
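The four components above can be sketched as a toy encoder-decoder in PyTorch. The `TinyUNet` class below is a hypothetical minimal version with a single downsampling stage (the real U-Net uses four) purely to show where the skip connection fits:

```python
import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    """Minimal encoder-decoder with one skip connection (illustrative only)."""
    def __init__(self, in_ch=3, num_classes=2):
        super().__init__()
        # 1. Encoder ("what"): extract features, then downsample
        self.enc = nn.Sequential(nn.Conv2d(in_ch, 16, 3, padding=1), nn.ReLU())
        self.pool = nn.MaxPool2d(2)
        # 2. Bottleneck: compressed representation
        self.bottleneck = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU())
        # 3. Decoder ("where"): transposed conv recovers spatial size
        self.up = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        # 4. After concatenating the skip: 16 (upsampled) + 16 (skip) = 32 channels
        self.dec = nn.Sequential(nn.Conv2d(32, 16, 3, padding=1), nn.ReLU())
        self.head = nn.Conv2d(16, num_classes, 1)  # per-pixel class logits

    def forward(self, x):
        skip = self.enc(x)                    # high-res features, saved for the skip
        z = self.bottleneck(self.pool(skip))  # low-res, high-level features
        z = self.up(z)                        # back to input resolution
        z = torch.cat([z, skip], dim=1)       # skip connection refines boundaries
        return self.head(self.dec(z))

model = TinyUNet()
out = model(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 2, 64, 64]) -- same spatial size as the input
```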

## 3. Loss Functions for Segmentation

Because we are classifying every pixel, plain pixel accuracy can be misleading: if 90% of the image is background, predicting "background" everywhere already scores 90%. We instead rely on overlap-based measures, whose differentiable ("soft") variants also serve as training losses:

- **Intersection over Union (IoU) / Jaccard Index:** Measures the overlap between the predicted mask and the ground truth.
- **Dice Coefficient:** Similar to IoU, it measures the similarity between two sets of data and is more robust to class imbalance.

$$ IoU = \frac{\text{Area of Overlap}}{\text{Area of Union}} $$
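Both measures reduce to a few set operations on boolean masks. A minimal sketch (the `eps` term is a common guard against division by zero for empty masks):

```python
import torch

def iou(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> float:
    """Intersection over Union (Jaccard index) between two binary masks."""
    pred, target = pred.bool(), target.bool()
    inter = (pred & target).sum().item()   # area of overlap
    union = (pred | target).sum().item()   # area of union
    return inter / (union + eps)

def dice(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> float:
    """Dice coefficient: 2*|A∩B| / (|A| + |B|)."""
    pred, target = pred.bool(), target.bool()
    inter = (pred & target).sum().item()
    return 2 * inter / (pred.sum().item() + target.sum().item() + eps)

pred = torch.zeros(4, 4); pred[:2, :2] = 1      # predicted mask: top-left 2x2
target = torch.zeros(4, 4); target[:2, 1:3] = 1  # ground truth, shifted one column
# Overlap = 2 px, union = 6 px -> IoU ≈ 1/3; Dice = 2*2 / (4+4) = 0.5
print(iou(pred, target), dice(pred, target))
```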

## 4. Real-World Applications

- **Medical Imaging:** Identifying tumors or mapping organs in MRI and CT scans.
- **Self-Driving Cars:** Identifying the exact boundaries of lanes, sidewalks, and drivable space.
- **Satellite Imagery:** Mapping land use, deforestation, or urban development.
- **Portrait Mode:** Separating the person (subject) from the background to apply a "bokeh" blur effect.

## 5. Popular Models

| Model | Type | Best For |
| --- | --- | --- |
| U-Net | Semantic | Medical imaging and biomedical research. |
| Mask R-CNN | Instance | Detecting objects and generating masks (e.g., counting individual cells). |
| DeepLabV3+ | Semantic | State-of-the-art results using Atrous (Dilated) Convolutions. |
| SegNet | Semantic | Efficient scene understanding for autonomous driving. |

## 6. Implementation Sketch (PyTorch)

Using a pre-trained segmentation model from torchvision:

```python
import torch
from torchvision import models

# Load a pre-trained DeepLabV3 model (weights download on first use).
# Newer torchvision versions prefer the `weights=` argument over `pretrained=True`.
model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()

# Input: (Batch, Channels, Height, Width)
dummy_input = torch.randn(1, 3, 224, 224)

# Output: a dictionary whose 'out' entry holds the pixel-wise class logits
with torch.no_grad():
    output = model(dummy_input)['out']

print(f"Output shape: {output.shape}")
# [1, 21, 224, 224] (21 Pascal VOC classes)
```
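To turn the `[batch, num_classes, H, W]` logits into a displayable segmentation map, take the arg-max over the class dimension. The snippet below uses a random stand-in tensor in place of the model output so it runs without downloading weights:

```python
import torch

# Stand-in for the DeepLabV3 'out' tensor: [batch, num_classes, H, W]
logits = torch.randn(1, 21, 224, 224)

# Per-pixel prediction: the highest-scoring class at each location
mask = logits.argmax(dim=1)  # shape [1, 224, 224], values in 0..20

print(mask.shape)
```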

Segmentation provides a high level of detail, but it's computationally expensive. How do we make these models faster for real-time applications?