import torchvision
import torch
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
# To make the transform usable by a torchvision dataset, it must be a callable that takes an image as input and returns the transformed image.
# Build the training augmentation pipeline once at import time rather than on
# every call: the dataset invokes its transform once per sample.
_TRAIN_TRANSFORM = A.Compose([
    A.HorizontalFlip(),
    A.Normalize(),
    ToTensorV2(),
])


def train_trans(image) -> torch.Tensor:
    """Apply the training-time augmentation pipeline to a single image.

    Args:
        image: Any image accepted by ``np.array`` (torchvision's CIFAR10
            passes a PIL image to its ``transform``); presumably HWC uint8 --
            confirm if reused with another dataset.

    Returns:
        The image after a random horizontal flip and albumentations'
        default normalization, converted to a channel-first tensor by
        ``ToTensorV2``.
    """
    return _TRAIN_TRANSFORM(image=np.array(image))["image"]
# Deterministic evaluation pipeline (no augmentation), built once at import
# time rather than on every per-sample call.
_TEST_TRANSFORM = A.Compose([
    A.Normalize(),
    ToTensorV2(),
])


def test_trans(image) -> torch.Tensor:
    """Apply the evaluation-time preprocessing pipeline to a single image.

    Args:
        image: Any image accepted by ``np.array`` (torchvision's CIFAR10
            passes a PIL image to its ``transform``); presumably HWC uint8 --
            confirm if reused with another dataset.

    Returns:
        The image normalized with albumentations' default statistics and
        converted to a channel-first tensor by ``ToTensorV2``.
    """
    return _TEST_TRANSFORM(image=np.array(image))["image"]
# CIFAR-10 train and validation splits, stored under ./data (downloaded on
# first use).
train_set = torchvision.datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=train_trans,
)
val_set = torchvision.datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=test_trans,
)

# Train on a random subset of 10,000 *distinct* training images.
# replace=False matters: np.random.choice samples with replacement by default,
# which would yield duplicate indices and a smaller effective subset.
# ``shuffle`` is deliberately not passed: DataLoader rejects shuffle=True
# together with a custom sampler, and SubsetRandomSampler already visits its
# indices in a new random order every epoch.
train_subset_indices = np.random.choice(len(train_set), 10000, replace=False)
train_loader = torch.utils.data.DataLoader(
    train_set,
    sampler=torch.utils.data.SubsetRandomSampler(train_subset_indices.tolist()),
    batch_size=64,
    num_workers=5,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    shuffle=False,
    batch_size=64 * 2,
    num_workers=5,
)
CIFAR-10 is a classic, fairly easy image-classification benchmark. We will use PyTorch and Lightning to do the training.
The advantage of this approach is that the same workflow runs either locally on your computer's CPU or on ten H100s from any cloud you can access.
Lightning handles placing the data and the optimization-related objects (model, optimizer, scheduler, etc.) on the right device and, last but not least, the metric computation, done with torchmetrics.
The torchmetrics metrics already implement gathering across GPUs/devices, so you only have to decide which ones to add to your project. If a computation you need is not already in the library, you can easily add your own metric.
The data
The model
Code
import lightning as L
from typing import Optional, List
import torchmetrics
from omegaconf import DictConfig, OmegaConf
class ClassificationModule(L.LightningModule):
    """LightningModule for multiclass image classification.

    Trains a classifier (a torchvision ResNet-18 unless another model is
    supplied) with cross-entropy loss and Adam, and logs aggregate and
    per-category torchmetrics for the train, validation and test stages.

    Args:
        categories: Class names in label-index order; ``len(categories)``
            sets the number of classes.
        config: Hyperparameters; ``config.lr`` is read as the Adam learning
            rate.
        model: Optional network to train. When ``None``, a ResNet-18 with
            ``len(categories)`` outputs is created.
    """

    def __init__(
        self,
        categories: List[str],
        config: DictConfig,
        model: Optional[torch.nn.Module] = None,
    ):
        super().__init__()
        # Record ``categories`` and ``config`` in checkpoints so saved weights
        # travel with the settings that produced them. ``model`` is excluded:
        # it is an nn.Module whose weights are already in the state dict.
        self.save_hyperparameters(ignore=["model"])
        self.categories = categories
        self.config = config
        num_classes = len(categories)

        # Use the caller's model when one is given; otherwise fall back to
        # a freshly initialised ResNet-18.
        if model is None:
            model = torchvision.models.resnet18(num_classes=num_classes)
        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss()

        common = {"task": "multiclass", "num_classes": num_classes}
        metrics = torchmetrics.MetricCollection([
            torchmetrics.classification.Accuracy(**common),
            torchmetrics.F1Score(**common),
            torchmetrics.Precision(**common),
            torchmetrics.Recall(**common),
            torchmetrics.CalibrationError(**common),
        ])
        # Separate clones so each stage accumulates its own metric state.
        self.train_metric = metrics.clone(prefix="Train/")
        self.val_metrics = metrics.clone(prefix="Validation/")
        self.test_metrics = metrics.clone(prefix="Test/")
        # average=None keeps one value per class instead of a single average.
        # Shared by validation and test; it is reset at the end of each epoch.
        self.per_category_metrics = torchmetrics.MetricCollection([
            torchmetrics.classification.Accuracy(**common, average=None),
            torchmetrics.F1Score(**common, average=None),
            torchmetrics.Precision(**common, average=None),
            torchmetrics.Recall(**common, average=None),
        ])

    def forward(self, X):
        """Return the wrapped model's output (logits) for input batch ``X``."""
        return self.model(X)

    def configure_optimizers(self):
        """Create the Adam optimizer using ``config.lr``.

        To add a learning-rate scheduler, return ``[optimizer], [scheduler]``
        instead of the bare optimizer.
        """
        return torch.optim.Adam(
            self.parameters(), lr=self.config.lr, weight_decay=1e-5
        )

    def training_step(self, batch, batch_idx):
        """Compute the loss on one ``(images, targets)`` batch and update train metrics."""
        images, targets = batch
        outputs = self(images)
        loss = self.criterion(outputs, targets)
        # update() only accumulates state; the per-batch values that calling
        # the collection would compute are never used.
        self.train_metric.update(outputs, targets)
        self.log("Train/Loss", loss, on_epoch=True, on_step=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        """Log the epoch's accumulated train metrics, then reset them."""
        self.log_dict(self.train_metric.compute())
        self.train_metric.reset()

    def _eval_step(self, batch, stage, metrics):
        """Shared validation/test step: log ``{stage}/Loss`` and update metric states."""
        images, targets = batch
        outputs = self(images)
        loss = self.criterion(outputs, targets)
        self.log(f"{stage}/Loss", loss, on_epoch=True, on_step=False)
        metrics.update(outputs, targets)
        self.per_category_metrics.update(outputs, targets)

    def _log_eval_epoch(self, stage, metrics):
        """Log aggregate and per-category metrics under ``stage``, then reset both."""
        self.log_dict(metrics.compute())
        per_category = self.per_category_metrics.compute()
        for mname, mresults in per_category.items():
            # mresults holds one value per class, in ``self.categories`` order.
            for catname, value in zip(self.categories, mresults):
                self.log(f"{stage}/{mname}_{catname}", value)
        metrics.reset()
        self.per_category_metrics.reset()

    def validation_step(self, batch, batch_idx):
        self._eval_step(batch, "Validation", self.val_metrics)

    def on_validation_epoch_end(self):
        self._log_eval_epoch("Validation", self.val_metrics)

    def test_step(self, batch, batch_idx):
        self._eval_step(batch, "Test", self.test_metrics)

    def on_test_epoch_end(self):
        # Per-category test metrics go under "Test/", not "Validation/",
        # so they never collide with validation keys.
        self._log_eval_epoch("Test", self.test_metrics)
# Hyperparameters consumed by ClassificationModule (``lr`` feeds Adam).
config = OmegaConf.create({"lr": 1e-5})

model = ClassificationModule(
    config=config,
    categories=train_set.classes,
)

# Trainer settings: 3 epochs in 16-bit mixed precision, validating after
# every epoch, with 2 sanity-check validation batches before training.
trainer = L.Trainer(
    precision="16-mixed",
    max_epochs=3,
    check_val_every_n_epoch=1,
    num_sanity_val_steps=2,
    log_every_n_steps=50,
    enable_checkpointing=True,
)

# Uncomment to launch training:
# trainer.fit(model, train_loader, val_loader)
Using 16bit Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
And that's it!
The model weights are saved together with the config that produced them.