Simple semantic segmentation example
Let's walk through a simple example of a segmentation task with a single foreground class (text).
First, import the common libraries and packages below (if any package is missing, install it first).
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from torch.nn.functional import one_hot
import segmentation_models_pytorch as smp
import random
from torchmetrics import *
import shutil
Then you can import the EasyPL packages, like so:
from easypl.learners import SegmentationLearner
from easypl.metrics import TorchMetric
from easypl.metrics.segmentation import PixelLevelF1
from easypl.optimizers import WrapperOptimizer
from easypl.lr_schedulers import WrapperScheduler
from easypl.datasets import CSVDatasetSegmentation
from easypl.callbacks import SegmentationImageLogger
from easypl.callbacks import Mixup
from easypl.losses.segmentation import DiceLoss
Then you should define the datasets and dataloaders. You can use this simple example:
dataset = CSVDatasetSegmentation(
csv_path='../input/lan-segmentation-1/train.csv',
image_prefix='../input/lan-segmentation-1/train/images',
mask_prefix='../input/lan-segmentation-1/train/masks',
image_column='path',
target_column='path'
)
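For reference, CSVDatasetSegmentation is assumed here to read a single path column whose file names are resolved against image_prefix and mask_prefix; the layout sketched below is an assumption, so check it against your own train.csv:
# Assumed layout of train.csv (illustrative only):
# path
# image_0001.png
# image_0002.png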
class WrapperDataset(Dataset):
    def __init__(self, dataset, transform=None, idxs=None):
        self.dataset = dataset
        self.transform = transform
        self.idxs = idxs if idxs is not None else []

    def __getitem__(self, idx):
        idx = self.idxs[idx]
        row = self.dataset[idx]
        # Binarize the mask: background = 0, text = 1.
        row['mask'][row['mask'] <= 150] = 0
        row['mask'][row['mask'] > 150] = 1
        if self.transform:
            result = self.transform(image=row['image'], mask=row['mask'])
            row['image'] = result['image']
            row['mask'] = result['mask'].to(dtype=torch.long)
            # One-hot encode the mask to shape (num_classes, H, W).
            row['mask'] = one_hot(row['mask'], num_classes=2).permute(2, 0, 1)
        return row

    def __len__(self):
        return len(self.idxs)
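To see what the mask encoding above produces, here is a tiny standalone check (the values are illustrative only):
# A 2x2 integer mask becomes a (num_classes, H, W) one-hot tensor.
m = torch.tensor([[0, 1], [1, 0]])
print(one_hot(m, num_classes=2).permute(2, 0, 1).shape)  # torch.Size([2, 2, 2])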
image_size = 768
train_transform = A.Compose([
A.HorizontalFlip(p=0.5),
A.ColorJitter(p=0.7),
A.LongestMaxSize(max_size=image_size),
A.PadIfNeeded(min_height=image_size, min_width=image_size, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
])
val_transform = A.Compose([
A.LongestMaxSize(max_size=image_size),
A.PadIfNeeded(min_height=image_size, min_width=image_size, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
])
size_dataset = len(dataset)
val_size = int(size_dataset * 0.1)
train_dataset = WrapperDataset(dataset, transform=train_transform, idxs=list(range(val_size, size_dataset)))
val_dataset = WrapperDataset(dataset, transform=val_transform, idxs=list(range(val_size)))
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)
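As an optional sanity check, you can pull a single batch and confirm that the tensors have the expected shapes:
batch = next(iter(train_dataloader))
print(batch['image'].shape)  # expected: torch.Size([4, 3, 768, 768])
print(batch['mask'].shape)   # expected: torch.Size([4, 2, 768, 768]) after one-hot encoding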
Then we should define the model (using segmentation_models_pytorch), the loss function, the optimizer and the metrics:
model = smp.UnetPlusPlus(
encoder_name="resnet18",
encoder_weights="imagenet",
in_channels=3,
decoder_use_batchnorm=False,
classes=2,
)
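A quick, optional forward pass on random data confirms the output shape: one channel per class at the input resolution.
# Sanity check of the model output shape (input size must be divisible by 32).
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, image_size, image_size))
print(out.shape)  # expected: torch.Size([1, 2, 768, 768])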
loss_f = DiceLoss(weight=torch.tensor([1, 10]))
optimizer = WrapperOptimizer(optim.Adam, lr=1e-4)
num_epochs = 7
num_gpus = 1
lr_scheduler = WrapperScheduler(
torch.optim.lr_scheduler.OneCycleLR, max_lr=3e-4, pct_start=1 / (num_epochs),
total_steps=int(len(train_dataloader) * num_epochs / num_gpus) + 10, div_factor=1e+3, final_div_factor=1e+4,
anneal_strategy='cos', interval='step'
)
class_names = ['background', 'text']
train_metrics = []
val_metrics = [
TorchMetric(PixelLevelF1(average='none', num_classes=len(class_names)), class_names),
]
If you need callbacks, you can use our simple implementations. Creating callbacks looks like this:
# Logger of outputs (images)
logger = SegmentationImageLogger(
phase='val',
max_samples=10,
num_classes=2,
save_on_disk=True,
dir_path='images'
)
# Mixup callback
mixup = Mixup(
    on_batch=True,
    p=1.0,
    domen='segmentation',
)
# Checkpoint callback (referenced by the trainer below)
checkpoint_callback = ModelCheckpoint(save_top_k=1, save_last=True)
Finally, we should define the learner and the trainer, and then run the training.
learner = SegmentationLearner(
model=model,
loss=loss_f,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
train_metrics=train_metrics,
val_metrics=val_metrics,
data_keys=['image'],
target_keys=['mask'],
multilabel=False
)
trainer = Trainer(gpus=num_gpus, callbacks=[logger, mixup, checkpoint_callback], max_epochs=num_epochs)
trainer.fit(learner, train_dataloaders=train_dataloader, val_dataloaders=[val_dataloader])
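After training, a minimal inference sketch could look like the following; it applies the val_transform defined above directly to the underlying model, and the file name example.jpg is a placeholder:
model.eval()
image = cv2.cvtColor(cv2.imread('example.jpg'), cv2.COLOR_BGR2RGB)  # placeholder path
tensor = val_transform(image=image)['image'].unsqueeze(0)
with torch.no_grad():
    logits = model(tensor)
pred_mask = logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)  # 0 = background, 1 = text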