import torchvision
import torch
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
# To be usable by a torchvision dataset, the transform must be a function that takes an image as input and returns an image.
def train_trans(image) -> torch.Tensor:
    """Apply the training augmentation pipeline to a single image.

    The image is converted to a NumPy array, passed through
    ``A.HorizontalFlip`` (with its default probability), ``A.Normalize``
    (with its default statistics) and finally ``ToTensorV2``.

    Args:
        image: One dataset image; anything ``np.array`` can convert
            (presumably a PIL image coming from the torchvision dataset --
            confirm against the caller).

    Returns:
        The entry stored under the ``"image"`` key of the pipeline output.
    """
    # NOTE: the pipeline is rebuilt on every call, exactly as before, so each
    # call gets a fresh Compose object.
    pipeline = A.Compose(
        [
            A.HorizontalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]
    )
    # Albumentations works on arrays and returns a dict keyed by target name.
    return pipeline(image=np.array(image))["image"]
def test_trans(image) -> torch.Tensor:
    """Apply the evaluation pipeline (no augmentation) to a single image.

    The image is converted to a NumPy array, normalized by ``A.Normalize``
    (with its default statistics) and converted by ``ToTensorV2``.

    Args:
        image: One dataset image; anything ``np.array`` can convert
            (presumably a PIL image coming from the torchvision dataset --
            confirm against the caller).

    Returns:
        The entry stored under the ``"image"`` key of the pipeline output.
    """
    # NOTE: the pipeline is rebuilt on every call, exactly as before.
    pipeline = A.Compose(
        [
            A.Normalize(),
            ToTensorV2(),
        ]
    )
    # Albumentations works on arrays and returns a dict keyed by target name.
    return pipeline(image=np.array(image))["image"]
# CIFAR-10 splits, both stored under ./data (download=True lets torchvision
# fetch the archive into that folder).
# Training split: goes through the augmenting transform.
train_set = torchvision.datasets.CIFAR10(
    root="data", train=True, transform=train_trans, download=True
)
# Held-out split: goes through the augmentation-free transform.
val_set = torchvision.datasets.CIFAR10(
    root="data", train=False, transform=test_trans, download=True
)
# Training loader over a random subset of at most 10,000 images.
train_loader = torch.utils.data.DataLoader(
    train_set,
    # ``replace=False`` keeps the drawn indices unique. NumPy's default is to
    # sample with replacement, which would repeat some images and leave fewer
    # than 10,000 distinct ones. The size is clamped so that sampling without
    # replacement cannot fail on a dataset smaller than 10,000 items.
    # ``shuffle`` is intentionally not passed: DataLoader does not accept it
    # together with ``sampler``, and SubsetRandomSampler already randomizes
    # the order of the indices.
    sampler=torch.utils.data.SubsetRandomSampler(
        np.random.choice(
            len(train_set),
            min(10000, len(train_set)),
            replace=False,
        )
    ),
    batch_size=64,
    num_workers=5,
)
val_loader= torch.utils.data.DataLoader(
val_set,
shuffle=False,
batch_size=64*2,
num_workers=5,
)CIFAR is a trivial problem in image classification. We will be using Pytorch and lightning in order to do the training.
The advantage of this approach is that the workflow can be run locally on the CPU of your computer or on ten H100s in any cloud you have access to.
Lightning handles the placement of the data and of the optimization-related objects (model, optimizer, scheduler, etc.) and, last but not least, the metrics computation, which is done with torchmetrics.
The metrics already implement the gathering across GPUs/devices, so you just have to decide which ones you want to add to your project. If a computation is not already present in the library, you can add your own metric very easily.
The data
The model
Code
import lightning as L
from typing import Optional, List
import torchmetrics
from omegaconf import DictConfig, OmegaConf
class ClassificationModule(L.LightningModule):
    """Lightning module for multiclass image classification.

    Wraps a classifier network, a cross-entropy loss and several torchmetrics
    collections: aggregate metrics for the train / validation / test stages,
    plus per-category metrics (``average=None``) that are logged once per
    category name at the end of validation and test epochs.
    """

    def __init__(
        self,
        categories: List[str],
        config: DictConfig,
        model: Optional[torch.nn.Module] = None,
    ):
        """Build the module.

        Args:
            categories: Category names; their count sets the number of
                classes, and their order is used to label the per-category
                metrics (index ``i`` of a metric is logged under
                ``categories[i]``).
            config: Configuration object; ``config.lr`` is read when the
                optimizer is created.
            model: Network to train. When ``None``, a torchvision
                ``resnet18`` with ``len(categories)`` outputs is created.
        """
        super().__init__()
        self.categories = categories
        num_classes = len(categories)
        self.config = config

        # A caller-supplied model must be stored too. Previously only the
        # default branch assigned ``self.model``, so passing a model left the
        # attribute undefined and ``forward`` failed.
        if model is None:
            model = torchvision.models.resnet18(num_classes=num_classes)
        self.model = model

        self.criterion = torch.nn.CrossEntropyLoss()

        # Aggregate metrics; one independent clone per stage so that their
        # accumulated states never mix.
        metrics = torchmetrics.MetricCollection([
            torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes),
            torchmetrics.F1Score(task="multiclass", num_classes=num_classes),
            torchmetrics.Precision(task="multiclass", num_classes=num_classes),
            torchmetrics.Recall(task="multiclass", num_classes=num_classes),
            torchmetrics.CalibrationError(task="multiclass", num_classes=num_classes),
        ])
        self.train_metric = metrics.clone(prefix="Train/")
        self.val_metrics = metrics.clone(prefix="Validation/")
        self.test_metrics = metrics.clone(prefix="Test/")

        # ``average=None`` requests one value per class instead of a single
        # aggregate. This collection is shared by validation and test and is
        # reset at the end of each of those epochs.
        self.per_category_metrics = torchmetrics.MetricCollection([
            torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes, average=None),
            torchmetrics.F1Score(task="multiclass", num_classes=num_classes, average=None),
            torchmetrics.Precision(task="multiclass", num_classes=num_classes, average=None),
            torchmetrics.Recall(task="multiclass", num_classes=num_classes, average=None),
        ])

    def forward(self, X):
        """Return the wrapped model's output for the batch ``X``."""
        return self.model(X)

    def configure_optimizers(self):
        """Create the Adam optimizer using ``config.lr`` and a fixed weight decay."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr, weight_decay=1e-5)
        # You can add a scheduler here as well and return it as
        # return [optimizer], [scheduler]
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute and return the training loss for one batch.

        Also updates the training metrics and logs ``Train/Loss`` per step
        and per epoch (shown in the progress bar).
        """
        images, targets = batch
        outputs = self(images)
        loss = self.criterion(outputs, targets)
        self.train_metric(outputs, targets)
        self.log("Train/Loss", loss, on_epoch=True, on_step=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        """Log the accumulated training metrics, then reset them."""
        self.log_dict(self.train_metric.compute())
        self.train_metric.reset()

    def validation_step(self, batch, batch_idx):
        """Log the epoch-level validation loss and update validation metrics."""
        images, targets = batch
        outputs = self(images)
        loss = self.criterion(outputs, targets)
        self.log("Validation/Loss", loss, on_epoch=True, on_step=False)
        self.val_metrics(outputs, targets)
        self.per_category_metrics(outputs, targets)

    def on_validation_epoch_end(self):
        """Log aggregate and per-category validation metrics, then reset them."""
        self.log_dict(self.val_metrics.compute())
        self._log_per_category_metrics("Validation")
        self.val_metrics.reset()

    def test_step(self, batch, batch_idx):
        """Log the epoch-level test loss and update test metrics."""
        images, targets = batch
        outputs = self(images)
        loss = self.criterion(outputs, targets)
        self.log("Test/Loss", loss, on_epoch=True, on_step=False)
        self.test_metrics(outputs, targets)
        self.per_category_metrics(outputs, targets)

    def on_test_epoch_end(self):
        """Log aggregate and per-category test metrics, then reset them."""
        self.log_dict(self.test_metrics.compute())
        # Per-category test results go under "Test/"; they used to be logged
        # under "Validation/", mixing them with the validation curves.
        self._log_per_category_metrics("Test")
        self.test_metrics.reset()

    def _log_per_category_metrics(self, stage: str) -> None:
        """Log every per-category metric as ``{stage}/{metric}_{category}`` and reset the collection."""
        per_category = self.per_category_metrics.compute()
        for metric_name, per_class_values in per_category.items():
            for index, category in enumerate(self.categories):
                self.log(f"{stage}/{metric_name}_{category}", per_class_values[index])
        self.per_category_metrics.reset()
# Hyper-parameters read by the module (only the learning rate here).
config = OmegaConf.create({"lr": 1e-5})

# No network is passed, so the module falls back to its default one; the
# category names come straight from the training dataset.
model = ClassificationModule(config=config, categories=train_set.classes)

## Use everything for train
# Trainer for a short, 3-epoch mixed-precision run.
trainer = L.Trainer(
    # Run length and numeric precision.
    precision="16-mixed",
    max_epochs=3,
    # Validation cadence: a 2-batch sanity check, then validation every epoch.
    check_val_every_n_epoch=1,
    num_sanity_val_steps=2,
    # Logging and checkpointing.
    log_every_n_steps=50,
    enable_checkpointing=True,
)
# trainer.fit(
# model,
# train_loader,
# val_loader
# )
Using 16bit Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
And it is done!
The weights of the model are saved with the config that produced them.