DeepLearning
code
ImageClassification
ComputerVision
ENG
Author

Julien Combes

Published

April 13, 2025

CIFAR-10 is a simple, well-studied image classification problem. We will use PyTorch and Lightning to do the training.

The advantage of this approach is that the same workflow runs locally on your computer's CPU or on ten H100s in whatever cloud you can get access to.

Lightning handles device placement for the data and the optimization-related objects (model, optimizer, scheduler, etc.) and, last but not least, the metric computation done with torchmetrics.

The metrics already implement gathering across GPUs/devices, so you just have to decide which ones you want to add to your project. If a computation is not already present in the library, you can add your own metric very easily, as sketched below.
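A custom metric is a subclass of torchmetrics.Metric: you register accumulators with add_state and torchmetrics handles the cross-device reduction for you. A minimal sketch (the metric name and logic here are illustrative, not part of this post):

import torch
import torchmetrics

class TopKErrorRate(torchmetrics.Metric):
    """Fraction of samples whose target is NOT in the top-k predictions (illustrative example)."""
    def __init__(self, k: int = 5):
        super().__init__()
        self.k = k
        # add_state registers buffers that torchmetrics syncs across GPUs/devices
        self.add_state("errors", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, logits: torch.Tensor, targets: torch.Tensor) -> None:
        topk = logits.topk(self.k, dim=1).indices         # (N, k) predicted class ids
        hit = (topk == targets.unsqueeze(1)).any(dim=1)   # (N,) is the target in the top-k?
        self.errors += (~hit).sum()
        self.total += targets.numel()

    def compute(self) -> torch.Tensor:
        return self.errors.float() / self.total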

The data

import torchvision
import torch
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
# To be usable by a torchvision dataset, the transform must be a plain
# function that takes a PIL image and returns the transformed tensor

def train_trans(image) -> torch.Tensor:
    transform = A.Compose([
        A.HorizontalFlip(),  # random flip as light augmentation
        A.Normalize(),       # defaults to ImageNet mean/std
        ToTensorV2()
    ])

    transformed = transform(image = np.array(image))

    return transformed["image"]

def test_trans(image) -> torch.Tensor:
    transform = A.Compose([
        A.Normalize(),  # same normalization, no augmentation at eval time
        ToTensorV2()
    ])

    transformed = transform(image = np.array(image))

    return transformed["image"]

train_set = torchvision.datasets.CIFAR10(
    root="data", 
    download=True, 
    train=True,
    transform=train_trans)

val_set = torchvision.datasets.CIFAR10(
    root="data", 
    download=True, 
    train=False,
    transform=test_trans)

train_loader = torch.utils.data.DataLoader(
    train_set,
    # shuffle=True,
    # train on a random 10k subset to keep the run short; replace=False avoids sampling duplicates
    sampler=torch.utils.data.SubsetRandomSampler(np.random.choice(len(train_set), 10000, replace=False)),
    batch_size=64,
    num_workers=5,
)

val_loader = torch.utils.data.DataLoader(
    val_set,
    shuffle=False,
    batch_size=64*2,
    num_workers=5,
)
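Since matplotlib is already imported, a quick sanity check on one batch is cheap. A minimal sketch; it assumes A.Normalize kept its default ImageNet mean/std, which we undo before plotting:

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# grab one batch and display the first few images with their labels
images, targets = next(iter(train_loader))
fig, axes = plt.subplots(1, 4, figsize=(8, 2))
for ax, img, label in zip(axes, images, targets):
    img = img.permute(1, 2, 0).numpy() * std + mean  # CHW -> HWC, undo normalization
    ax.imshow(img.clip(0, 1))
    ax.set_title(train_set.classes[int(label)])
    ax.axis("off")
plt.show()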

The model

import lightning as L
from typing import Optional, List
import torchmetrics
from omegaconf import DictConfig, OmegaConf

class ClassificationModule(L.LightningModule):
    def __init__(
        self, 
        categories :List[str],
        config:DictConfig,
        model: Optional[torch.nn.Module] = None, 
        ):
        
        super().__init__()
        self.categories = categories
        num_classes = len(categories)
        self.config = config
        if model is None:
            self.model = torchvision.models.resnet18(num_classes=num_classes)
        else:
            self.model = model

        # store categories and config in the checkpoint alongside the weights
        self.save_hyperparameters(ignore=["model"])

        self.criterion = torch.nn.CrossEntropyLoss()

        metrics = torchmetrics.MetricCollection([
            torchmetrics.classification.Accuracy(task = "multiclass", num_classes = num_classes),
            torchmetrics.F1Score(task = "multiclass", num_classes = num_classes),
            torchmetrics.Precision(task = "multiclass", num_classes = num_classes),
            torchmetrics.Recall(task = "multiclass", num_classes = num_classes),
            torchmetrics.CalibrationError(task = "multiclass", num_classes = num_classes),
        ])

        self.train_metrics = metrics.clone(prefix="Train/")
        self.val_metrics = metrics.clone(prefix="Validation/")
        self.test_metrics = metrics.clone(prefix="Test/")

        self.per_category_metrics = torchmetrics.MetricCollection([
            torchmetrics.classification.Accuracy(task = "multiclass", num_classes = num_classes, average = None),
            torchmetrics.F1Score(task = "multiclass", num_classes = num_classes, average = None),
            torchmetrics.Precision(task = "multiclass", num_classes = num_classes, average = None),
            torchmetrics.Recall(task = "multiclass", num_classes = num_classes, average = None),
        ])

    def forward(self, X):
        return self.model(X)

    def configure_optimizers(self):

        # Define Optimizer here
        optimizer = torch.optim.Adam(self.parameters(), lr = self.config.lr, weight_decay=1e-5)

        # You can also add a scheduler here and return both, e.g.:
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)
        # return [optimizer], [scheduler]
        return optimizer

    def training_step(self, batch, batch_idx):
        images, targets = batch

        outputs = self(images)

        loss = self.criterion(outputs, targets)
        
        self.train_metrics(outputs, targets)

        self.log("Train/Loss",loss, on_epoch=True, on_step=True, prog_bar=True)

        return loss
    
    def on_train_epoch_end(self):

        train_metrics = self.train_metrics.compute()

        self.log_dict(train_metrics)

        self.train_metrics.reset()
    
    def validation_step(self, batch, batch_idx):
        images, targets = batch

        outputs = self(images)

        loss = self.criterion(outputs, targets)
        self.log("Validation/Loss", loss, on_epoch=True, on_step=False)

        self.val_metrics(outputs, targets)
        self.per_category_metrics(outputs, targets)

        
    
    def on_validation_epoch_end(self):

        val_metrics =  self.val_metrics.compute()

        self.log_dict(val_metrics)

        m = self.per_category_metrics.compute()
        for mname, mresults in m.items():
            for i, catname in enumerate(self.categories):
                self.log(f"Validation/{mname}_{catname}", mresults[i])

        self.val_metrics.reset()
        self.per_category_metrics.reset()
    

    def test_step(self, batch, batch_idx):
        images, targets = batch

        outputs = self(images)

        loss = self.criterion(outputs, targets)
        self.log("Test/Loss", loss, on_epoch=True, on_step=False)

        self.test_metrics(outputs, targets)
        self.per_category_metrics(outputs, targets)

        
    
    def on_test_epoch_end(self):

        test_metrics =  self.test_metrics.compute()

        self.log_dict(test_metrics)
        m = self.per_category_metrics.compute()
        for mname, mresults in m.items():
            for i, catname in enumerate(self.categories):
                self.log(f"Validation/{mname}_{catname}", mresults[i])

        self.test_metrics.reset()
        self.per_category_metrics.reset()


config = OmegaConf.create({
    "lr": 1e-5
})

model = ClassificationModule(
    categories=train_set.classes,
    config=config
)

## Use everything for training

trainer = L.Trainer(
    max_epochs=3,
    precision = "16-mixed",
    enable_checkpointing=True,
    num_sanity_val_steps=2,
    log_every_n_steps=50,
    check_val_every_n_epoch=1,
)
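This Trainer is where the CPU-versus-cloud portability from the introduction cashes out: scaling up only changes Trainer arguments, not the module or the loaders. A sketch of a multi-GPU configuration (the device count and strategy values are illustrative), kept commented so the local run stays the one executed:

# trainer = L.Trainer(
#     max_epochs=3,
#     precision="16-mixed",
#     accelerator="gpu",
#     devices=8,        # e.g. one multi-GPU node
#     strategy="ddp",   # distributed data parallel; torchmetrics syncs across ranks
# )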

trainer.fit(
    model,
    train_loader,
    val_loader
)
Using 16bit Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
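The test_step and on_test_epoch_end hooks defined above only run through an explicit trainer.test call. A minimal sketch; the CIFAR-10 test split already backs val_loader, so we simply reuse it here:

# runs the test loop (test_step / on_test_epoch_end) on the given loader
trainer.test(model, dataloaders=val_loader)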

And it is done!

The model weights are saved in the checkpoint together with the config that produced them.
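Reloading is then a one-liner. A sketch assuming the save_hyperparameters call in __init__ stored categories and config; the checkpoint path is a placeholder for the one the Trainer actually writes:

# path is a placeholder; Lightning writes checkpoints under lightning_logs/ by default
ckpt_path = "path/to/checkpoint.ckpt"
restored = ClassificationModule.load_from_checkpoint(ckpt_path)
restored.eval()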