# Train a Segmentation Model
This segmentation example is divided into the following steps:

1. Download the CamVid dataset
2. Implement the CamVid dataset class
3. Implement the segmentation model
4. Train the model with Runner
> **Note:** You can also run this example as a notebook.
## Download the CamVid Dataset
First, download the CamVid dataset from OpenDataLab:
```bash
# Dataset homepage: https://opendatalab.com/CamVid
# Install (or upgrade) the OpenDataLab CLI
pip install -U opendatalab
# Log in with your OpenDataLab account
odl login
# Download the dataset
mkdir data
odl get CamVid -d data
# Extract the data on Linux. On Windows, extract the files to `data` manually.
tar -xzvf data/CamVid/raw/CamVid.tar.gz.00 -C ./data
```
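
After extraction, `data/CamVid` should contain the image folders, label folders, and the color palette file used later in this tutorial. Below is an optional sanity check, a minimal sketch that only verifies the layout the `CamVid` class in the next section expects:

```python
import os

root = 'data/CamVid'
# These entries are exactly what the dataset class below relies on.
for entry in ['train', 'train_labels', 'val', 'val_labels', 'class_dict.csv']:
    path = os.path.join(root, entry)
    print(entry, 'found' if os.path.exists(path) else 'MISSING')
```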
## Implement the CamVid Dataset
The `CamVid` class below inherits from `VisionDataset`. It overrides the `__getitem__` and `__len__` methods so that each index returns a dict of images and labels, and it builds a `color_to_class` dictionary that maps each mask color to a class index.
```python
import csv
import os

import numpy as np
from PIL import Image
from torchvision.datasets import VisionDataset


def create_palette(csv_filepath):
    """Map each (r, g, b) color in class_dict.csv to a class index."""
    color_to_class = {}
    with open(csv_filepath, newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for idx, row in enumerate(reader):
            r, g, b = int(row['r']), int(row['g']), int(row['b'])
            color_to_class[(r, g, b)] = idx
    return color_to_class


class CamVid(VisionDataset):

    def __init__(self,
                 root,
                 img_folder,
                 mask_folder,
                 transform=None,
                 target_transform=None):
        super().__init__(
            root, transform=transform, target_transform=target_transform)
        self.img_folder = img_folder
        self.mask_folder = mask_folder
        self.images = list(
            sorted(os.listdir(os.path.join(self.root, img_folder))))
        self.masks = list(
            sorted(os.listdir(os.path.join(self.root, mask_folder))))
        self.color_to_class = create_palette(
            os.path.join(self.root, 'class_dict.csv'))

    def __getitem__(self, index):
        img_path = os.path.join(self.root, self.img_folder,
                                self.images[index])
        mask_path = os.path.join(self.root, self.mask_folder,
                                 self.masks[index])
        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('RGB')  # Convert to RGB

        if self.transform is not None:
            img = self.transform(img)

        # Pack each RGB triple into a single integer, then map it to a class
        # index. Cast to int64 first so the multiplication cannot overflow
        # the uint8 pixel values.
        mask = np.array(mask).astype(np.int64)
        mask = mask[:, :, 0] * 65536 + mask[:, :, 1] * 256 + mask[:, :, 2]
        labels = np.zeros_like(mask, dtype=np.int64)
        for color, class_index in self.color_to_class.items():
            rgb = color[0] * 65536 + color[1] * 256 + color[2]
            labels[mask == rgb] = class_index

        if self.target_transform is not None:
            labels = self.target_transform(labels)
        data_samples = dict(
            labels=labels, img_path=img_path, mask_path=mask_path)
        return img, data_samples

    def __len__(self):
        return len(self.images)
```
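
To check that the dataset is wired up correctly, you can load a single sample. This is a sketch that assumes the extraction above succeeded; no transforms are applied, so `img` is a raw PIL image:

```python
dataset = CamVid('data/CamVid', img_folder='train', mask_folder='train_labels')
img, sample = dataset[0]
print(len(dataset))            # number of training images
print(sample['labels'].shape)  # (H, W) array of class indices
print(sample['img_path'])      # path of the source image
```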
We then use the `CamVid` dataset to build the `train_dataloader` and `val_dataloader`, which serve as the data loaders for training and validation in the subsequent Runner.
```python
import torch
import torchvision.transforms as transforms

norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(**norm_cfg)])

target_transform = transforms.Lambda(
    lambda x: torch.tensor(np.array(x), dtype=torch.long))

train_set = CamVid(
    'data/CamVid',
    img_folder='train',
    mask_folder='train_labels',
    transform=transform,
    target_transform=target_transform)

valid_set = CamVid(
    'data/CamVid',
    img_folder='val',
    mask_folder='val_labels',
    transform=transform,
    target_transform=target_transform)

train_dataloader = dict(
    batch_size=3,
    dataset=train_set,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'))

val_dataloader = dict(
    batch_size=3,
    dataset=valid_set,
    sampler=dict(type='DefaultSampler', shuffle=False),
    collate_fn=dict(type='default_collate'))
```
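
Note that the dataloaders here are plain dicts: the Runner turns them into real `torch.utils.data.DataLoader` objects at build time. As a rough sketch only (the Runner's actual construction also handles seeding and distributed samplers), the `train_dataloader` dict corresponds to something like:

```python
from torch.utils.data import DataLoader
from mmengine.dataset import DefaultSampler

# 'default_collate' is also PyTorch's default collate_fn, so it is omitted.
loader = DataLoader(
    train_set,
    batch_size=3,
    sampler=DefaultSampler(train_set, shuffle=True))
```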
## Implement the Segmentation Model
The code below defines a model class named `MMDeeplabV3`. It derives from `BaseModel` and wraps torchvision's DeepLabV3 segmentation model. It overrides the `forward` method to handle both input images and labels, computing losses in training mode and returning predictions in prediction mode.

For additional information about `BaseModel`, refer to the Model tutorial.
```python
import torch.nn.functional as F
from mmengine.model import BaseModel
from torchvision.models.segmentation import deeplabv3_resnet50


class MMDeeplabV3(BaseModel):

    def __init__(self, num_classes):
        super().__init__()
        self.deeplab = deeplabv3_resnet50(num_classes=num_classes)

    def forward(self, imgs, data_samples=None, mode='tensor'):
        x = self.deeplab(imgs)['out']
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, data_samples['labels'])}
        elif mode == 'predict':
            return x, data_samples
        # 'tensor' mode: return the raw logits
        return x
```
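
A minimal smoke test for the model in `'tensor'` mode. This is a sketch: the input size is arbitrary, the segmentation head is randomly initialized, and torchvision may download pretrained backbone weights on first use:

```python
model = MMDeeplabV3(num_classes=32).eval()  # eval() avoids batch-norm updates
dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(dummy, mode='tensor')
print(out.shape)  # expected: torch.Size([1, 32, 224, 224])
```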
## Training with Runner
Before training with the Runner, we need to implement the IoU (Intersection over Union) metric to evaluate the model’s performance.
```python
from mmengine.evaluator import BaseMetric


class IoU(BaseMetric):

    def process(self, data_batch, data_samples):
        # `data_samples` is what the model returns in 'predict' mode:
        # (logits, dict of labels and paths).
        preds, labels = data_samples[0], data_samples[1]['labels']
        preds = torch.argmax(preds, dim=1)
        intersect = (labels == preds).sum()
        union = (torch.logical_or(preds, labels)).sum()
        iou = (intersect / union).cpu()
        self.results.append(
            dict(batch_size=len(labels), iou=iou * len(labels)))

    def compute_metrics(self, results):
        total_iou = sum(result['iou'] for result in self.results)
        num_samples = sum(result['batch_size'] for result in self.results)
        return dict(iou=total_iou / num_samples)
```
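
The two halves fit together as follows: `process` is called once per validation batch and appends a per-batch record to `self.results`, and `compute_metrics` aggregates those records into the final score. A standalone sketch with synthetic tensors (not real data):

```python
metric = IoU()
fake_logits = torch.randn(2, 32, 8, 8)         # (N, num_classes, H, W)
fake_labels = torch.randint(0, 32, (2, 8, 8))  # (N, H, W)
metric.process(None, (fake_logits, dict(labels=fake_labels)))
print(metric.compute_metrics(metric.results))  # e.g. {'iou': tensor(...)}
```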
We also implement a visualization hook to make it easier to compare predictions against the ground-truth labels.
```python
import os
import os.path as osp
import shutil

import cv2
from mmengine.hooks import Hook


class SegVisHook(Hook):

    def __init__(self, data_root, vis_num=1) -> None:
        super().__init__()
        self.vis_num = vis_num
        self.palette = create_palette(osp.join(data_root, 'class_dict.csv'))

    def after_val_iter(self,
                       runner,
                       batch_idx: int,
                       data_batch=None,
                       outputs=None) -> None:
        if batch_idx > self.vis_num:
            return
        preds, data_samples = outputs
        img_paths = data_samples['img_path']
        mask_paths = data_samples['mask_path']
        _, C, H, W = preds.shape
        preds = torch.argmax(preds, dim=1)
        for idx, (pred, img_path,
                  mask_path) in enumerate(zip(preds, img_paths, mask_paths)):
            pred_mask = np.zeros((H, W, 3), dtype=np.uint8)
            runner.visualizer.set_image(pred_mask)
            # Paint every predicted class with its palette color.
            for color, class_id in self.palette.items():
                runner.visualizer.draw_binary_masks(
                    pred == class_id,
                    colors=[color],
                    alphas=1.0,
                )
            # Convert RGB to BGR for cv2.imwrite.
            pred_mask = runner.visualizer.get_image()[..., ::-1]
            saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx))
            os.makedirs(saved_dir, exist_ok=True)
            # Save the source image, the ground-truth mask, and the
            # prediction together for easy comparison.
            shutil.copyfile(img_path,
                            osp.join(saved_dir, osp.basename(img_path)))
            shutil.copyfile(mask_path,
                            osp.join(saved_dir, osp.basename(mask_path)))
            cv2.imwrite(
                osp.join(saved_dir, f'pred_{osp.basename(img_path)}'),
                pred_mask)
```
Finally, train the model with the Runner!
```python
from torch.optim import AdamW
from mmengine.optim import AmpOptimWrapper
from mmengine.runner import Runner

num_classes = 32  # Modify to the actual number of categories.
runner = Runner(
    model=MMDeeplabV3(num_classes),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(
        type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),
    train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=IoU),
    custom_hooks=[SegVisHook('data/CamVid')],
    default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),
)
runner.train()
```
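
With `interval=1`, the `CheckpointHook` saves a checkpoint after every epoch under the work directory. A minimal sketch for reloading the trained weights afterwards (the file name `epoch_10.pth` is an assumption based on the settings above):

```python
model = MMDeeplabV3(num_classes=32)
# MMEngine checkpoints store the model weights under the 'state_dict' key.
ckpt = torch.load('./work_dir/epoch_10.pth', map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
model.eval()
```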
Finally, you can check the training results in the folder `./work_dir/{timestamp}/vis_data`.
Each saved sample contains three files side by side: the input image, the predicted mask, and the ground-truth label.