import torchvision
import torch
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import os
# To be usable by a torchvision dataset, the transform must be a function that takes an image as input and returns an image as well.
# Root directory under which the datasets are stored / downloaded.
# It is read from the DATA_ROOTPATH environment variable. Without a fallback,
# os.environ.get returns None when the variable is unset and every later
# os.path.join(datarootdir, ...) fails with a TypeError, so default to a
# local "data" directory instead.
datarootdir = os.environ.get("DATA_ROOTPATH", "data")
def train_trans(image) -> torch.Tensor:
    """Apply the training augmentation pipeline to a single image.

    A torchvision dataset calls its transform with the image as the only
    positional argument, while an albumentations pipeline is called with
    keyword arguments and returns a dict. This wrapper adapts one calling
    convention to the other.

    Args:
        image: The image to transform. It is converted with ``np.array``
            before being handed to albumentations (presumably a PIL image
            coming from the CIFAR10 dataset -- confirm against the caller).

    Returns:
        The value stored under the ``"image"`` key of the pipeline output,
        i.e. the flipped/normalized image after ``ToTensorV2``.
    """
    transform = A.Compose([
        A.HorizontalFlip(),  # random augmentation, used for training only
        A.Normalize(),       # library-default normalization statistics
        ToTensorV2(),
    ])
    transformed = transform(image=np.array(image))
    return transformed["image"]
def test_trans(image) -> torch.Tensor:
    """Apply the deterministic evaluation pipeline to a single image.

    Same adapter as the training transform (torchvision passes the image
    positionally, albumentations expects keywords and returns a dict), but
    without the random horizontal flip: only normalization and tensor
    conversion are applied.

    Args:
        image: The image to transform. It is converted with ``np.array``
            before being handed to albumentations (presumably a PIL image
            coming from the CIFAR10 dataset -- confirm against the caller).

    Returns:
        The value stored under the ``"image"`` key of the pipeline output,
        i.e. the normalized image after ``ToTensorV2``.
    """
    transform = A.Compose([
        A.Normalize(),  # library-default normalization statistics
        ToTensorV2(),
    ])
    transformed = transform(image=np.array(image))
    return transformed["image"]
# Both CIFAR10 splits are stored in (and downloaded to, if missing) the same
# "cifar10" folder under the data root.
_cifar10_root = os.path.join(datarootdir, "cifar10")

# Training split, augmented with the training transform.
train_set = torchvision.datasets.CIFAR10(
    root=_cifar10_root, train=True, transform=train_trans, download=True
)

# Held-out split (train=False), using the augmentation-free transform.
val_set = torchvision.datasets.CIFAR10(
    root=_cifar10_root, train=False, transform=test_trans, download=True
)
# Train on a random subset of (at most) 10,000 training images.
# replace=False is required here: np.random.choice samples WITH replacement by
# default, which would put duplicate indices in the subset and leave it with
# fewer than 10,000 distinct images.
train_loader = torch.utils.data.DataLoader(
    train_set,
    # `shuffle` is not passed: it cannot be combined with `sampler`, and
    # SubsetRandomSampler already visits the chosen indices in random order.
    sampler=torch.utils.data.SubsetRandomSampler(
        np.random.choice(
            len(train_set),
            min(10000, len(train_set)),  # never ask for more than exists
            replace=False,
        )
    ),
    batch_size=64,
    num_workers=5,
)
# Validation loader: iterates over the whole validation set in a fixed order.
val_loader= torch.utils.data.DataLoader(
val_set,
# keep a fixed sample order across validation runs
shuffle=False,
# twice the training batch size (64)
batch_size=64*2,
num_workers=5,
)
CIFAR is a trivial problem in image classification. We will be using PyTorch and Lightning to do the training.
The advantage of this approach is that the same workflow can run locally on your computer's CPU or on ten H100s from any cloud you can get access to.
Lightning handles the device placement of the data and of the optimization-related objects (model, optimizer, scheduler, etc.) and, last but not least, the metrics computation, which is done with torchmetrics.
The metrics already implement gathering across GPUs/devices, so you just have to decide which ones you want to add to your project. If a computation is not already present in the library, you can add your own metric very easily.
The data
The model
Code
import lightning as L
from typing import Optional, List
import torchmetrics
from omegaconf import DictConfig, OmegaConf
class ClassificationModule(L.LightningModule):
    """LightningModule for multi-class image classification.

    Bundles a backbone network, a cross-entropy loss, an Adam optimizer and
    torchmetrics collections for the train / validation / test stages.
    Aggregated metrics are logged under the ``Train/``, ``Validation/`` and
    ``Test/`` prefixes; per-category (unaveraged) metrics are logged at the
    end of every validation and test epoch as ``<Stage>/<metric>_<category>``.

    Args:
        categories: Class names. Their count defines the number of classes,
            and their position is used to index the per-category metric
            results, so the order must match the target indices.
        config: Run configuration; only ``config.lr`` (the learning rate) is
            read by this class.
        model: Optional network to train. When ``None``, a torchvision
            ResNet-18 with ``len(categories)`` outputs is created.
    """

    def __init__(
        self,
        categories: List[str],
        config: DictConfig,
        model: Optional[torch.nn.Module] = None,
    ):
        super().__init__()
        self.categories = categories
        num_classes = len(categories)
        self.config = config

        if model is None:
            model = torchvision.models.resnet18(num_classes=num_classes)
        self.model = model

        self.criterion = torch.nn.CrossEntropyLoss()

        # Aggregated metrics; one independent copy per stage so that their
        # internal states never mix.
        metrics = torchmetrics.MetricCollection([
            torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes),
            torchmetrics.F1Score(task="multiclass", num_classes=num_classes),
            torchmetrics.Precision(task="multiclass", num_classes=num_classes),
            torchmetrics.Recall(task="multiclass", num_classes=num_classes),
            torchmetrics.CalibrationError(task="multiclass", num_classes=num_classes),
        ])
        self.train_metric = metrics.clone(prefix="Train/")
        self.val_metrics = metrics.clone(prefix="Validation/")
        self.test_metrics = metrics.clone(prefix="Test/")

        # Conditional performances of our estimator: average=None keeps one
        # value per class. This single collection is shared by the validation
        # and test stages and is reset at the end of each of their epochs.
        self.per_category_metrics = torchmetrics.MetricCollection([
            torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes, average=None),
            torchmetrics.F1Score(task="multiclass", num_classes=num_classes, average=None),
            torchmetrics.Precision(task="multiclass", num_classes=num_classes, average=None),
            torchmetrics.Recall(task="multiclass", num_classes=num_classes, average=None),
        ])

    def forward(self, X):
        """Return the wrapped model's output for the batch ``X``."""
        return self.model(X)

    def configure_optimizers(self):
        """Create the Adam optimizer using the learning rate from the config."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr, weight_decay=1e-5)
        # A scheduler can be added here as well and returned as
        #   return [optimizer], [scheduler]
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute the loss on one training batch and update the train metrics."""
        images, targets = batch
        outputs = self(images)
        loss = self.criterion(outputs, targets)
        self.train_metric(outputs, targets)
        self.log("Train/Loss", loss, on_epoch=True, on_step=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        """Log the train metrics accumulated over the epoch, then reset them."""
        self.log_dict(self.train_metric.compute())
        self.train_metric.reset()

    def validation_step(self, batch, batch_idx):
        """Accumulate loss and metrics for one validation batch."""
        self._evaluation_step(batch, "Validation", self.val_metrics)

    def on_validation_epoch_end(self):
        """Log aggregated and per-category validation metrics, then reset them."""
        self._log_epoch_metrics("Validation", self.val_metrics)

    def test_step(self, batch, batch_idx):
        """Accumulate loss and metrics for one test batch."""
        self._evaluation_step(batch, "Test", self.test_metrics)

    def on_test_epoch_end(self):
        """Log aggregated and per-category test metrics, then reset them.

        The per-category values are logged under the ``Test/`` prefix (they
        used to be logged under ``Validation/``, which mislabelled them).
        """
        self._log_epoch_metrics("Test", self.test_metrics)

    def _evaluation_step(self, batch, stage, stage_metrics):
        """Shared validation/test step: log ``<stage>/Loss`` and update metrics."""
        images, targets = batch
        outputs = self(images)
        loss = self.criterion(outputs, targets)
        self.log(f"{stage}/Loss", loss, on_epoch=True, on_step=False)
        stage_metrics(outputs, targets)
        self.per_category_metrics(outputs, targets)

    def _log_epoch_metrics(self, stage, stage_metrics):
        """Log the epoch-level results of ``stage_metrics`` and of the
        per-category metrics under the ``<stage>/`` prefix, then reset both."""
        self.log_dict(stage_metrics.compute())
        per_category = self.per_category_metrics.compute()
        for metric_name, per_class_values in per_category.items():
            for index, category in enumerate(self.categories):
                self.log(f"{stage}/{metric_name}_{category}", per_class_values[index])
        stage_metrics.reset()
        self.per_category_metrics.reset()
# Run configuration. `lr` is the learning rate given to the optimizer in
# ClassificationModule.configure_optimizers.
config = OmegaConf.create(dict(lr=1e-5))
# Build the Lightning module for the CIFAR10 class names. No `model` argument
# is passed, so the module creates its default ResNet-18 backbone.
model = ClassificationModule(
categories=train_set.classes,
config=config
)
Use everything for training!
# Trainer settings, gathered in one mapping so they are easy to read and tweak.
_trainer_kwargs = {
    "max_epochs": 3,
    "precision": "16-mixed",          # 16-bit automatic mixed precision
    "enable_checkpointing": True,
    "num_sanity_val_steps": 2,        # validation batches run before training starts
    "log_every_n_steps": 50,
    "check_val_every_n_epoch": 1,     # validate after every training epoch
}
trainer = L.Trainer(**_trainer_kwargs)
# trainer.fit(
# model,
# train_loader,
# val_loader
# )
Using 16bit Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
And it is done!
The weights of the model are saved with the config that produced them.