---
title: "Reliable Image Classifiers"
author: "Julien Combes"
date: "2025-12-13"
categories: [DeepLearning, ENG, ConformalPrediction, code]
image: "image.png"
---
Let's make a ResNet18 produce reliable predictions on the CIFAR datasets!
# Motivation and method
Given the amount of data generated since the 2000s, modelling processes and understanding the underlying patterns in these numbers has become increasingly difficult. The time of modelling with linear relations and homoskedastic errors has passed for most advanced use cases. The number of different models we could use for any given problem is large, and the field of extracting information from data has been taken over by our brothers the computer scientists. Today, engineers and applied researchers are more result-oriented than method-oriented, meaning that if the model makes few mistakes we just deploy it, period. This new empirical methodology of modelling with any possible model until we find a good one leads to reliability issues.
Indeed, most advanced machine learning models, such as boosted trees, ensembles and the latest neural networks, do not provide any way to estimate the uncertainty associated with their predictions. When it comes to predicting diseases from medical images, understanding the confidence in your diagnosis is essential.
That is why we need a statistical tool that can be used with any model and any type of data, while making as few assumptions as possible.
Conformal prediction is the tool we need to accomplish that! (I am still a noob but willing to learn about this :) )
The field was first introduced by Gammerman, Vovk & Vapnik [@gammermanLearningTransduction1998]; a detailed history is given by Angelopoulos in his gentle introduction to conformal prediction [@angelopoulosGentleIntroductionConformal2022].
In short, conformal prediction is a tool that is **model independent** and **independent of the data distribution**. The only assumption of this framework is that the data points are **exchangeable**, meaning that the vector of data points and any permutation of it have the same joint distribution (whatever that distribution is).
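Formally, for data points $Z_1, \dots, Z_n$ (here, image/label pairs) and any permutation $\sigma$ of $\{1, \dots, n\}$, exchangeability reads

$$(Z_1, \dots, Z_n) \overset{d}{=} (Z_{\sigma(1)}, \dots, Z_{\sigma(n)}).$$

This is weaker than the usual i.i.d. assumption: i.i.d. data are exchangeable, but exchangeable data need not be independent.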
It is particularly beneficial for neural networks, the type of model we favor when dealing with unstructured data like images (in medical imaging or machine vision); these fields require extreme robustness and are expected to be reliable. That is why I believe conformal prediction could be very helpful in my work. This tool gives a statistical guarantee that the true label is predicted with a required probability.
In the next part, I show how to implement it for image classification using a plug-and-play Lightning Callback on a pre-existing ResNet18. In classification, a conformal predictor creates prediction sets that contain the true label with high probability. We want the prediction sets to be as small as possible, meaning that the model is confident in its prediction for a given input. We evaluate the quality of the conformal predictor with the empirical coverage (to make sure the desired coverage is reached), the average set size and the proportion of singletons predicted.
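Concretely, for a new test point $(X_{n+1}, Y_{n+1})$ and a user-chosen error rate $\alpha$, the conformal predictor builds a set $\mathcal{C}(X_{n+1})$ such that

$$\mathbb{P}\bigl(Y_{n+1} \in \mathcal{C}(X_{n+1})\bigr) \ge 1 - \alpha,$$

where the probability is marginal, i.e. taken over both the calibration data and the test point.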
# Experimental Setup
Here we load data and models using torch and lightning.
```{python}
#| code-fold: true
import torchvision
import torch
import math
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import os
import lightning as L
from typing import Optional, List
import torchmetrics
from omegaconf import DictConfig, OmegaConf
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import Callback
from torchcp.classification.predictor import SplitPredictor
from torchcp.classification.score import APS, RAPS, SAPS
from torchmetrics import Metric
# to be usable by a torchvision dataset, the transform must be a function that takes a PIL image and returns a tensor
datarootdir = os.environ.get("DATA_ROOTPATH")
def train_trans(image) -> torch.Tensor:
transform = A.Compose([
A.HorizontalFlip(),
A.Normalize(),
ToTensorV2()
])
transformed = transform(image = np.array(image))
return transformed["image"]
def test_trans(image) -> torch.Tensor:
transform = A.Compose([
A.Normalize(),
ToTensorV2()
])
transformed = transform(image = np.array(image))
return transformed["image"]
train_set = torchvision.datasets.CIFAR10(
root=os.path.join(datarootdir, "cifar10"),
download=False,
train=True,
transform=train_trans)
val_set = torchvision.datasets.CIFAR10(
root=os.path.join(datarootdir, "cifar10"),
download=False,
train=False,
transform=test_trans)
train_loader= torch.utils.data.DataLoader(
train_set,
shuffle=True,
batch_size=64,
num_workers=5,
)
val_indices = np.random.choice(len(val_set), round(len(val_set)/2), replace= False)
test_indices= np.setdiff1d(np.arange(len(val_set)), val_indices)
val_loader= torch.utils.data.DataLoader(
torch.utils.data.Subset(val_set,val_indices),
shuffle=False,
batch_size=64*2,
num_workers=5,
)
test_loader= torch.utils.data.DataLoader(
torch.utils.data.Subset(val_set,test_indices),
shuffle=False,
batch_size=64*2,
num_workers=5,
)
class ClassificationModule(L.LightningModule):
def __init__(
self,
categories :List[str],
config:DictConfig,
):
super().__init__()
self.categories = categories
num_classes = len(categories)
self.config = config
model = torchvision.models.resnet18(
weights="IMAGENET1K_V1")
model.fc= torch.nn.Linear(in_features = 512, out_features = num_classes)
self.model = model
self.criterion = torch.nn.CrossEntropyLoss()
metrics = torchmetrics.MetricCollection([
torchmetrics.classification.Accuracy(task = "multiclass", num_classes = num_classes),
torchmetrics.F1Score(task = "multiclass", num_classes = num_classes),
torchmetrics.Precision(task = "multiclass", num_classes = num_classes),
torchmetrics.Recall(task = "multiclass", num_classes = num_classes),
torchmetrics.CalibrationError(task = "multiclass", num_classes = num_classes),
])
self.train_metric = metrics.clone(prefix="Train/")
self.val_metrics = metrics.clone(prefix="Validation/")
self.test_metrics = metrics.clone(prefix="Test/")
# conditional performances of our estimator
self.per_category_metrics = torchmetrics.MetricCollection([
torchmetrics.classification.Accuracy(task = "multiclass", num_classes = num_classes, average = None),
torchmetrics.F1Score(task = "multiclass", num_classes = num_classes, average = None),
torchmetrics.Precision(task = "multiclass", num_classes = num_classes, average = None),
torchmetrics.Recall(task = "multiclass", num_classes = num_classes, average = None),
])
def forward(self, X):
return self.model(X)
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr = self.config.lr, weight_decay=1e-5, momentum=0.9)
# you can add a scheduler here as well and return it as
# return [optimizer], [scheduler]
#
return optimizer
def training_step(self, batch, batch_idx):
images, targets = batch
outputs = self(images)
loss = self.criterion(outputs, targets)
self.train_metric(outputs, targets)
self.log("Train/Loss",loss, on_epoch=True, on_step=True, prog_bar=True)
return loss
def on_train_epoch_end(self):
train_metrics= self.train_metric.compute()
self.log_dict(train_metrics)
self.train_metric.reset()
def validation_step(self, batch, batch_idx):
images, targets = batch
outputs = self(images)
loss = self.criterion(outputs, targets)
self.log("Validation/Loss", loss, on_epoch=True, on_step=False)
self.val_metrics(outputs, targets)
self.per_category_metrics(outputs, targets)
return outputs
def on_validation_epoch_end(self):
val_metrics = self.val_metrics.compute()
self.log_dict(val_metrics)
m = self.per_category_metrics.compute()
for mname, mresults in m.items():
for i, catname in enumerate(self.categories):
self.log(f"Validation/{mname}_{catname}", mresults[i])
self.val_metrics.reset()
self.per_category_metrics.reset()
def test_step(self, batch, batch_idx):
images, targets = batch
outputs = self(images)
loss = self.criterion(outputs, targets)
self.log("Test/Loss", loss, on_epoch=True, on_step=False)
self.test_metrics(outputs, targets)
self.per_category_metrics(outputs, targets)
return outputs
def on_test_epoch_end(self):
test_metrics = self.test_metrics.compute()
# self.log_dict(test_metrics)
m = self.per_category_metrics.compute()
# for mname, mresults in m.items():
# for i, catname in enumerate(self.categories):
# self.log(f"Test/{mname}_{catname}", mresults[i])
self.test_metrics.reset()
self.per_category_metrics.reset()
config = OmegaConf.create({
"lr": 1e-3
})
# model = ClassificationModule(
# categories=train_set.classes,
# config=config
# )
# trainer= L.Trainer(
# max_epochs=100,
# precision = "16-mixed",
# enable_checkpointing=True,
# num_sanity_val_steps=0,
# log_every_n_steps=50,
# check_val_every_n_epoch=1,
# callbacks=[
# EarlyStopping("Validation/Loss", patience=2,mode="min"),
# ModelCheckpoint(monitor="Validation/Loss", filename='cifar10-{epoch:02d}')
# ]
# )
# I won't retrain it each time
# trainer.fit(
# model,
# train_loader,
# val_loader
# )
# trainer.test(model,val_loader)
model= ClassificationModule.load_from_checkpoint(
"lightning_logs/version_36/checkpoints/cifar10-epoch=03.ckpt",
categories=train_set.classes,
config=config
)
```
# Calibration
Calibration is the property of an estimator that allows it to quantify its uncertainty in a reliable manner. In classification, for example, when the model predicts a class with 60% confidence, it should be right 60% of the time. No model has such a guarantee out of the box, which is why we need tools that quantify a model's uncertainty with sound statistical guarantees, without relying on strong assumptions about the model or the data.
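One common way to write this property, for the predicted top class $\hat{Y}$ with confidence score $\hat{p}$, is

$$\mathbb{P}\bigl(Y = \hat{Y} \mid \hat{p} = p\bigr) = p \quad \text{for all } p \in [0, 1].$$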
Here is the calibration of this network for each class of the CIFAR-10 dataset (the network is pre-trained on ImageNet, so it is actually a bit too good). The following plots are the reliability curves of the network for each class: the closer a curve is to the diagonal $y=x$, the better the calibration. The expected calibration error (ECE) is computed as well.
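As a reminder, the binned ECE reported with the diagrams below is (up to the exact estimator used by `relplot`)

$$\mathrm{ECE} = \sum_{b=1}^{B} \frac{|\mathcal{B}_b|}{n}\,\bigl|\mathrm{acc}(\mathcal{B}_b) - \mathrm{conf}(\mathcal{B}_b)\bigr|,$$

where $\mathcal{B}_b$ is the set of predictions whose confidence falls into bin $b$, $\mathrm{acc}$ is the accuracy and $\mathrm{conf}$ the average confidence within that bin.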
```{python}
#| layout-ncol: 5
#| echo: false
import relplot as rp
preds = []
targets = []
model.eval()
with torch.no_grad():
for b in test_loader:
im, tar = b
pred = model(im.cuda()).softmax(1)
preds.append(pred)
targets.append(tar)
preds = torch.cat(preds)
targets = torch.cat(targets)
for c in range(10):
threshold = 0.02
targets_class = (targets==c).int().cpu()
preds_class = preds[:,c].cpu()
t_preds = preds_class[preds_class > threshold]
t_targets = targets_class[preds_class > threshold]
rp.rel_diagram_binned(
t_preds,
t_targets,
nbins=30,
)
plt.title(f"Class {c}")
```
Here we calibrate our conformal predictor with a split predictor; the calibration data will be the validation dataset.
We implement it as a Lightning Callback, a simple interface that automatically loops over the dataloader passed to the `validate` method.
The idea of implementing this through the Lightning trainer comes from a paper that applied conformal prediction to semantic segmentation [@brunekreefKandinskyConformalPrediction2023].
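Concretely, with non-conformity scores $s_1, \dots, s_n$ computed on the $n$ calibration points, the split predictor sets

$$\hat{q} = \text{the } \tfrac{\lceil (n+1)(1-\alpha) \rceil}{n}\text{-th empirical quantile of } s_1, \dots, s_n,$$

and the prediction set for a new input $x$ is $\mathcal{C}(x) = \{\, y : s(x, y) \le \hat{q} \,\}$. This is what the `on_validation_end` and `on_test_batch_end` hooks below compute.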
```{python}
class CalibratorCallback(Callback):
def __init__(self, score_function:str = "aps", alpha:float= 0.1):
match score_function:
case "aps":
self.score_fn = APS(score_type="softmax")
case "raps":
self.score_fn = RAPS(score_type="softmax", penalty = 0.1, kreg = 5)
case "saps":
self.score_fn = SAPS(score_type="softmax")
self.alpha = alpha
self.pred_set_sizes = []
self.n_singletons = []
self.covered = []
def on_validation_start(self,trainer, pl_module):
self.nc_scores = []
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
im, targets= batch
batch_nc_scores = self.score_fn(outputs, targets)
self.nc_scores.extend(batch_nc_scores)
def on_validation_end(self, trainer, pl_module):
self.nc_scores = torch.tensor(self.nc_scores)
N = self.nc_scores.shape[0]
quantile_value = math.ceil((N + 1) * (1 - self.alpha)) / N
q_hat = torch.kthvalue(self.nc_scores, math.ceil(N*quantile_value), dim=0).values.to(self.nc_scores.device)
pl_module.q_hat = q_hat
self.q_hat = q_hat
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
checkpoint["q_hat"] = self.q_hat
def on_test_start(self, trainer, pl_module):
assert self.q_hat
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int, dataloader_idx: int=0):
labels = batch[1]
batch_nc_scores = self.score_fn(outputs)
prediction_sets = (batch_nc_scores <= self.q_hat).int()
self.pred_set_sizes.extend(prediction_sets.sum(1))
self.n_singletons.extend(prediction_sets.sum(1)==1)
self.covered.extend(prediction_sets[torch.arange(len(labels)), labels])
def on_test_epoch_end(self, trainer, pl_module):
sizes = torch.tensor(self.pred_set_sizes).float().mean()
singletons = torch.tensor(self.n_singletons).float().mean()
covered = torch.tensor(self.covered).float().mean()
self.log("Test/covered",covered)
self.log("Test/sizes",sizes)
self.log("Test/singletons",singletons)
for score_fn in ["aps", "raps", "saps"]:
trainer= L.Trainer(
max_epochs=1,
precision = "16-mixed",
enable_checkpointing=False,
num_sanity_val_steps=0,
log_every_n_steps=0,
check_val_every_n_epoch=0,
callbacks=[
CalibratorCallback(alpha= 0.1, score_function=score_fn)
]
)
trainer.validate(model, val_loader, verbose = False);
trainer.test(model, test_loader);
```
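For intuition, here is a minimal sketch of the (non-randomized) APS non-conformity score, written independently of the torchcp implementation used above; the helper `aps_score` is purely illustrative. The score of a class is the softmax mass accumulated when walking down the classes ranked by probability, until that class is included.

```{python}
import torch

def aps_score(probs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """probs: (N, K) softmax outputs, labels: (N,) true classes."""
    sorted_probs, order = probs.sort(dim=1, descending=True)  # rank classes by probability
    cum_mass = sorted_probs.cumsum(dim=1)                     # cumulative mass down the ranking
    ranks = order.argsort(dim=1)                              # rank of each original class index
    # mass accumulated until (and including) the true class
    return cum_mass.gather(1, ranks.gather(1, labels[:, None])).squeeze(1)

# a confident, well-ranked true class gets a small score, a poorly ranked one a large score
probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.3, 0.6]])
print(aps_score(probs, torch.tensor([0, 0])))  # ≈ tensor([0.7000, 1.0000])
```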
After computing the non-conformity scores on our calibration set, we can look at the calibration images with the highest and the lowest non-conformity scores:
```{python}
#| layout-ncol: 2
#| echo: false
# analysis of highest and lowest calibration points
val_set.transform = lambda x:x
max_score = trainer.callbacks[0].nc_scores.argmax()
image , target = torch.utils.data.Subset(val_set,val_indices)[max_score]
plt.imshow(
np.array(image)
)
plt.title(f"Most non conformative calibration image : {val_set.classes[target]}")
plt.axis("off")
plt.figure()
min_score = trainer.callbacks[0].nc_scores.argmin()
image , target = torch.utils.data.Subset(val_set,val_indices)[min_score]
plt.imshow(
np.array(image)
)
plt.title(f"Most conformative calibration image : {val_set.classes[target]}")
plt.axis("off")
```
When working with conformal prediction we want to track the characteristics of our predictor. We are interested in (a tiny worked example follows the list):

- the coverage rate (how often the true label is indeed in the prediction set),
- the proportion of singletons (the proportion of prediction sets containing only one element),
- the average prediction set size.
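These quantities can be computed directly from the binary prediction-set matrix. Here is a tiny, made-up example (the `prediction_sets` values are hypothetical, not taken from the experiment):

```{python}
# hypothetical prediction sets for 3 samples and 3 classes, plus their true labels
import torch

prediction_sets = torch.tensor([[1, 0, 0],
                                [1, 1, 0],
                                [0, 0, 1]]).bool()
labels = torch.tensor([0, 2, 2])

coverage = prediction_sets[torch.arange(len(labels)), labels].float().mean()  # true label in set
avg_size = prediction_sets.sum(1).float().mean()                              # average set size
singletons = (prediction_sets.sum(1) == 1).float().mean()                     # share of singletons
print(coverage, avg_size, singletons)  # tensor(0.6667) tensor(1.3333) tensor(0.6667)
```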
We can see that the coverage rate is equal to 89%, a little lower than the 90% we asked for with $\alpha=10\%$. This is not surprising: we reuse a single calibration split for all the test points, so the realized coverage on this particular test set fluctuates around the nominal level.
The coverage is still high.
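More precisely, the $1-\alpha$ guarantee is marginal over the draw of the calibration set: conditional on a fixed calibration set of size $n$ (and assuming continuous scores), the realized coverage follows a Beta distribution [@angelopoulosGentleIntroductionConformal2022],

$$\mathbb{P}\bigl(Y_{\text{test}} \in \mathcal{C}(X_{\text{test}}) \mid \text{calibration set}\bigr) \sim \mathrm{Beta}(n + 1 - l,\, l), \qquad l = \lfloor (n+1)\alpha \rfloor,$$

so observing 89% instead of 90% on a single calibration/test split is not a violation of the guarantee.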
For 60% of the test points the prediction set is a singleton, meaning that for those inputs the network commits to a single class while keeping the 90% coverage guarantee.
And lastly, the average prediction set size is 1.58, which is small considering the number of classes (10).
All the score functions provide the required coverage, and (without any statistical testing, to limit the computational burden) it looks like the more advanced the score function, the smaller the prediction sets. This is precisely why these scores exist, so it is good news that we observe what we are supposed to observe.
## More classes
Here we run the same experiment with more classes (CIFAR-100) to see how this affects the size of the prediction sets.
```{python}
#| code-fold: true
bs = 32
train_set = torchvision.datasets.CIFAR100(
root=os.path.join(datarootdir, "cifar100"),
download=False,
train=True,
transform=train_trans)
val_set = torchvision.datasets.CIFAR100(
root=os.path.join(datarootdir, "cifar100"),
download=False,
train=False,
transform=test_trans)
train_loader= torch.utils.data.DataLoader(
train_set,
shuffle=True,
batch_size=bs,
num_workers=5,
)
val_indices = np.random.choice(len(val_set), round(len(val_set)/2), replace= False)
test_indices= np.setdiff1d(np.arange(len(val_set)), val_indices)
val_loader= torch.utils.data.DataLoader(
torch.utils.data.Subset(val_set,val_indices),
shuffle=False,
batch_size=bs*2,
num_workers=5,
)
test_loader= torch.utils.data.DataLoader(
torch.utils.data.Subset(val_set,test_indices),
shuffle=False,
batch_size=bs*2,
num_workers=5,
)
model = ClassificationModule(
categories=train_set.classes,
config=config
)
trainer= L.Trainer(
max_epochs=3,
precision = "16-mixed",
enable_checkpointing=True,
num_sanity_val_steps=0,
log_every_n_steps=50,
check_val_every_n_epoch=1,
callbacks=[
EarlyStopping("Validation/Loss", patience=2,mode="min"),
ModelCheckpoint(monitor="Validation/Loss", filename='cifar100-{epoch:02d}')
]
)
# I won't retrain it each time
# trainer.fit(
# model,
# train_loader,
# val_loader
# )
# trainer.test(model,val_loader)
model= ClassificationModule.load_from_checkpoint(
"lightning_logs/version_73/checkpoints/cifar100-epoch=02.ckpt",
categories=train_set.classes,
config=config
)
for score_fn in ["aps", "raps", "saps"]:
trainer= L.Trainer(
max_epochs=1,
precision = "16-mixed",
enable_checkpointing=False,
num_sanity_val_steps=0,
log_every_n_steps=0,
check_val_every_n_epoch=0,
callbacks=[
CalibratorCallback(alpha= 0.1, score_function=score_fn)
]
)
trainer.validate(model, val_loader, verbose = False);
trainer.test(model, test_loader);
```
Here the number of possible classes is 100. The proportion of singletons predicted is at most 20%, and even 0 for the SAPS score function. Note that while SAPS produces no singletons, its average prediction set size is smaller. Relative to $\#\mathcal{Y}=100$, these are of course still large sets.
We can also see that the requested coverage of 90% is still achieved, which is very good.
```{python}
#| layout-ncol: 2
#| echo: false
# analysis of highest and lowest calibration points
val_set.transform = lambda x:x
max_score = trainer.callbacks[0].nc_scores.argmax()
image , target = torch.utils.data.Subset(val_set,val_indices)[max_score]
plt.imshow(
np.array(image)
)
plt.title(f"Most non conformative calibration image : {val_set.classes[target]}")
plt.axis("off")
plt.figure()
min_score = trainer.callbacks[0].nc_scores.argmin()
image , target = torch.utils.data.Subset(val_set,val_indices)[min_score]
plt.imshow(
np.array(image)
)
plt.title(f"Most conformative calibration image : {val_set.classes[target]}")
plt.axis("off")
```
# Conclusion
I am still learning about conformal prediction, and this exercise helped me a lot in understanding the computation and the logic behind the APS algorithm.
Being able to do it with Lightning using callbacks is very nice; it loses some of the benefits of the torchcp interface, but we can still rely on its score computations, so I would say it is a good trade-off.
One remark: using callbacks like this does not play well with HPC or multi-GPU settings, since the scores and set statistics are accumulated in plain Python lists on each process. I think implementing the different quantities we want to track as custom torchmetrics with proper aggregation logic would allow a more efficient computation of the conformalized prediction sets. The calibration step would still suffer from the same issue.
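For example, the coverage could be tracked with a custom `torchmetrics` metric along these lines. This is a minimal sketch under my assumptions, not code from this post: `ConformalCoverage` is a hypothetical class, and `q_hat` is assumed to come from a prior calibration step.

```{python}
import torch
from torchmetrics import Metric

class ConformalCoverage(Metric):
    """Fraction of samples whose prediction set contains the true label."""
    def __init__(self, q_hat: float):
        super().__init__()
        self.q_hat = q_hat  # conformal threshold, assumed already calibrated
        # dist_reduce_fx tells torchmetrics how to merge states across processes
        self.add_state("covered", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, scores: torch.Tensor, labels: torch.Tensor) -> None:
        # scores: (N, K) non-conformity scores for every class, labels: (N,)
        prediction_sets = scores <= self.q_hat
        idx = torch.arange(len(labels), device=labels.device)
        self.covered += prediction_sets[idx, labels].sum()
        self.total += len(labels)

    def compute(self) -> torch.Tensor:
        return self.covered / self.total

# illustrative usage on random scores
metric = ConformalCoverage(q_hat=0.9)
metric.update(torch.rand(8, 10), torch.randint(0, 10, (8,)))
print(metric.compute())
```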