from torchvision.datasets import VOCDetection
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
import torch
from omegaconf import OmegaConf
from torchvision.ops import box_convert
from torchvision.utils import draw_bounding_boxes
import albumentations as A
from transformers import DFineForObjectDetection, AutoImageProcessor
from albumentations.pytorch import ToTensorV2
from torchvision.ops import batched_nms
import torchmetrics
import lightning as L
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
from typing import Dict, List, Optional
import os

Object detection is the second type of problem we tackle in this blog, using Lightning and Pascal VOC. It is a very common problem because it lets us locate a variable number of objects in an image.
Because it powers use cases such as autonomous driving and real-time object counting, object detection architectures are usually grouped into two families. On the real-time side, the current champion is YOLO (version 11 at the time of writing). Among the more recent networks able to compete with YOLO at real-time object detection (RTOD) is D-FINE, an improved DETR.
We implement D-FINE in this article and train it with Lightning, since the training loop can then be reused with other models and datasets. I don't know what the best training paradigm is, but I want one that stays as framework-agnostic as possible; for me, Lightning allows strong customization and stays close to raw PyTorch, while making multi-device training, metric logging, and so on easy.
Data
We will model Pascal VOC images in this project because the dataset is easily accessible through torchvision.
D-FINE expects bounding boxes in COCO format (x, y, width, height), so we convert them in __getitem__.
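As a minimal sketch of that conversion (illustrative values only), torchvision's box_convert maps (xmin, ymin, xmax, ymax) to COCO's (x, y, width, height):

boxes_xyxy = torch.tensor([[48., 240., 195., 371.]])  # xmin, ymin, xmax, ymax
box_convert(boxes_xyxy, "xyxy", "xywh")
# tensor([[ 48., 240., 147., 131.]])  -> top-left x, top-left y, width, height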
categories = ['pottedplant', 'bottle', 'chair', 'diningtable', 'person', 'car', 'train', 'bus', 'bicycle', 'cat', 'aeroplane', 'tvmonitor', 'sofa', 'sheep', 'dog', 'bird', 'motorbike', 'horse', 'boat', 'cow', "bg"]
def unpack_box(box_dict: Dict):
    """
    Unpack a VOC bndbox dict into an (xmin, ymin, xmax, ymax) tensor.
    The XML parser yields coordinates as strings; np casts them to float.
    """
    return torch.tensor(np.array([
        box_dict["xmin"],
        box_dict["ymin"],
        box_dict["xmax"],
        box_dict["ymax"]
    ], dtype=float))
def annotation_to_torch(target: Dict):
    """Convert a parsed VOC annotation into {labels, boxes} tensors (xyxy)."""
    rep = {}
    detections = target["annotation"]["object"]
    rep["labels"] = np.array([categories.index(i["name"]) for i in detections])
    # Boxes are stacked as xmin, ymin, xmax, ymax.
    rep["boxes"] = torch.stack([unpack_box(i["bndbox"]) for i in detections])
    return rep
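For reference, torchvision parses each VOC XML file into a nested dict. Here is a stripped-down example (illustrative values; the real dict also carries filename, size, etc.) and its conversion:

sample_target = {
    "annotation": {
        "object": [
            {"name": "dog",
             "bndbox": {"xmin": "48", "ymin": "240", "xmax": "195", "ymax": "371"}},
        ]
    }
}
annotation_to_torch(sample_target)
# {'labels': array([14]), 'boxes': tensor([[ 48., 240., 195., 371.]], dtype=torch.float64)}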
image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine_x_coco", use_fast=True)
class MyVoc(VOCDetection):
def __getitem__(self, index):
image, target = super().__getitem__(index)
        # Parse the XML annotation into tensors before augmenting.
target = annotation_to_torch(target)
transform = A.Compose([
A.PadIfNeeded(500,500),
A.HorizontalFlip(),
A.RandomCrop(400,400),
A.Resize(224, 224),
A.Normalize(normalization="min_max"),
ToTensorV2()
],
bbox_params=A.BboxParams(
format='pascal_voc', # Specify input format
label_fields=['class_labels'],
filter_invalid_bboxes=True)
)
transformed = transform(
image=np.array(image),
bboxes=target["boxes"],
class_labels=target["labels"]
)
transformed["labels"] = transformed["class_labels"]
transformed["boxes"] = transformed["bboxes"]
transformed.pop("bboxes")
image = transformed.pop("image")
transformed["boxes"] = box_convert(
torch.from_numpy(transformed["boxes"],
),
"xyxy",
"xywh",).float()
transformed["labels"] = torch.from_numpy(transformed["labels"])
transformed["class_labels"] = torch.from_numpy(transformed["class_labels"])
return image.float(), transformed
def draw_item(self, index:Optional[int] = None, n=5):
if index is None:
index = np.random.randint(0, len(self))
image, labels = self[index]
with_boxes = draw_bounding_boxes(
image = image,
boxes= box_convert(labels["boxes"], "xywh", "xyxy"),
labels = [categories[i] for i in labels["labels"]],
colors = "red"
)
plt.figure(figsize=(10, 10))
plt.imshow(with_boxes.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()
        return

Nothing special here: we just display a random sample with its boxes to sanity-check the data pipeline.
datarootdir = os.environ.get("DATA_ROOTPATH", "~/data_wsl/voc")
ds = MyVoc(
root = os.path.join(datarootdir, "voc"),
download = False,
image_set="train"
)
val_ds = MyVoc(
root = os.path.join(datarootdir, "voc"),
download = False,
image_set="val",
)
ds.draw_item()
Modeling
The modeling part follows, wrapped in a LightningModule.
config = OmegaConf.create({
"lr": 1e-4,
"batch_size":2,
"epochs":3,
"world_size":1
})
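Hyperparameters live in an OmegaConf object. One convenience this buys us (a hypothetical usage, not wired up in this post) is merging command-line overrides on top of the defaults:

# e.g. running `python train.py lr=5e-5 epochs=10` would override the values above
config = OmegaConf.merge(config, OmegaConf.from_cli())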
def apply_nms(preds: Dict, iou_thr: float = .5):
    """Run class-aware NMS on one image's predictions and keep the survivors."""
    nms_indices = batched_nms(
        preds["boxes"],
        scores=preds["scores"],
        idxs=preds["labels"],
        iou_threshold=iou_thr
    )
    preds_nms = {}
    preds_nms["boxes"] = preds["boxes"][nms_indices, :]
    preds_nms["scores"] = preds["scores"][nms_indices]
    preds_nms["labels"] = preds["labels"][nms_indices]
    # (Optionally, a score threshold could also be applied here.)
    return preds_nms
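A quick sanity check with made-up tensors: two overlapping boxes of the same class collapse to the higher-scoring one, while the same box under another class survives, since batched_nms only suppresses within a class:

dummy = {
    "boxes": torch.tensor([[0., 0., 10., 10.], [1., 1., 10., 10.], [0., 0., 10., 10.]]),
    "scores": torch.tensor([0.9, 0.8, 0.7]),
    "labels": torch.tensor([1, 1, 2]),
}
apply_nms(dummy, iou_thr=0.5)["labels"]
# tensor([1, 2])  -> the 0.8-score duplicate of class 1 was suppressed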
class odModule(L.LightningModule):
def __init__(self, config, categories:List[str], nms_thr:float = .5):
super().__init__()
# if config.checkpoint is not None:
# print(f"checkpoint from {config.checkpoint}")
num_classes = len(categories)
self.categories = categories
self.nms_thr = nms_thr
self.config = config
# self.model = fasterrcnn_mobilenet_v3_large_fpn(pretrained = False, num_classes=num_classes)
self.model = DFineForObjectDetection.from_pretrained(
"ustc-community/dfine_x_coco",
id2label= {i:cat for i,cat in enumerate(categories)},
label2id={cat:i for i,cat in enumerate(categories)},
ignore_mismatched_sizes=True,
)
metrics = torchmetrics.MetricCollection(
[
torchmetrics.detection.mean_ap.MeanAveragePrecision(
# extended_summary=True,
iou_thresholds=np.linspace(0,1,20).tolist(),
class_metrics=True,
iou_type="bbox",
),
]
)
self.val_metrics = metrics.clone(prefix="Validation/")
self.test_metrics = metrics.clone(prefix="Test/")
self.save_hyperparameters(ignore=["train_ds"])
@staticmethod
def prepare_batch(batch):
images, targets = batch
return images, targets
def forward(self, x, y=None):
if y is not None:
return self.model(pixel_values = x, labels = y)
else:
preds = self.model(x)
n, c, h, w = x.shape
preds = image_processor.post_process_object_detection(
preds,
target_sizes=[(h,w) for _ in range(n)],
threshold=0.5)
return preds
    def predict(self, x):
        """Forward pass plus post-processing, used for evaluation.

        Args:
            x: batch of images of shape (N, C, H, W).

        Returns:
            A list of per-image dicts with boxes, scores and labels.
        """
        preds: List = self(x)
        # NMS is disabled for now: DETR-style decoders are largely NMS-free.
        # preds_nms = [apply_nms(i, self.nms_thr) for i in preds]
        return preds
def training_step(self, batch, batch_idx):
img_b, target_b = self.prepare_batch(batch)
bs = len(img_b)
dfine_output = self(img_b, target_b)
self.log_dict(
dfine_output.loss_dict,
on_step=False,
on_epoch=True,
sync_dist=True,
batch_size=bs,
prog_bar=True,
)
return {"loss": dfine_output.loss}
    def validation_step(self, batch, batch_idx):
        img_b, target_b = self.prepare_batch(batch)
        preds = self.predict(img_b)
        # Post-processed predictions are xyxy while our targets are stored as
        # xywh; convert the targets so both sides match the metric's default format.
        target_xyxy = [
            {**t, "boxes": box_convert(t["boxes"], "xywh", "xyxy")} for t in target_b
        ]
        self.val_metrics(preds, target_xyxy)
        return
def on_validation_epoch_end(self):
m = self.val_metrics.compute()
m_single= {i:j for i,j in m.items() if j.nelement() ==1}
self.log_dict(m_single, on_epoch=True, sync_dist=False)
for i, class_id in enumerate(m["Validation/classes"]):
self.log(f"Validation/MAP {self.categories[class_id]}", m["Validation/map_per_class"][i])
self.val_metrics.reset()
return
    def test_step(self, batch, batch_idx):
        img_b, target_b = self.prepare_batch(batch)
        preds = self.predict(img_b)
        # Same xywh -> xyxy conversion as in validation_step.
        target_xyxy = [
            {**t, "boxes": box_convert(t["boxes"], "xywh", "xyxy")} for t in target_b
        ]
        self.test_metrics(preds, target_xyxy)
        return
def on_test_epoch_end(self):
m = self.test_metrics.compute()
m_single= {i:j for i,j in m.items() if j.nelement() ==1}
self.log_dict(m_single, on_epoch=True, sync_dist=False)
for i, class_id in enumerate(m["Test/classes"]):
self.log(f"Test/MAP {self.categories[class_id]}", m["Test/map_per_class"][i])
self.test_metrics.reset()
return
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.config.lr,
            # Scale weight decay linearly with the batch size.
            weight_decay=1e-4 * self.config.batch_size / 16,
        )
        # A cosine LR schedule could be returned alongside the optimizer here.
        return [optimizer]
model = odModule(
    config,
    categories
)

def collate_fn(batch):
    # Detection targets vary in size (each image has a different number of
    # boxes), so we keep them as a list of dicts rather than stacking them.
    images = []
    targets = []
    for image, target in batch:
        images.append(image)
        targets.append(target)
    # All images share the same shape after Resize, so they stack into one tensor.
    images = torch.stack(images, dim=0)
    return images, targets
train_loader = DataLoader(
    ds,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=collate_fn
)
val_loader = DataLoader(
    val_ds,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=collate_fn
)
im, tar = next(iter(train_loader))
tar

[{'class_labels': tensor([9]),
  'labels': tensor([9]),
  'boxes': tensor([[ 84.5600,  82.3200,  85.1200, 125.4400]])},
 {'class_labels': tensor([17, 17,  4]),
  'labels': tensor([17, 17,  4]),
  'boxes': tensor([[  0.0000,  66.6400, 196.0000, 157.3600],
          [  0.0000, 117.6000, 124.3200, 106.4000],
          [129.3600, 168.5600,  22.4000,  24.6400]])}]
Training
The training is launched below. A full run has not been completed yet, but no bugs surfaced in the first batches; max_steps is capped so we can smoke-test the pipeline cheaply.
model.train()
trainer= L.Trainer(
max_epochs=3,
max_steps = 10, # limit training to make sure it works.
precision = "16-mixed",
enable_checkpointing=True,
num_sanity_val_steps=0,
log_every_n_steps=50,
check_val_every_n_epoch=1,
)
trainer.fit(model, train_loader, val_loader)

Using 16bit Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
/mnt/big_partition/projet/JulienCombes/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/mnt/big_partition/projet/JulienCombes/.venv/lib/python3.13/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary. Estimated model size in MB will not be accurate. Using 32 bits instead.
| Name | Type | Params | Mode
-----------------------------------------------------------------
0 | model | DFineForObjectDetection | 62.5 M | train
1 | val_metrics | MetricCollection | 0 | train
2 | test_metrics | MetricCollection | 0 | train
-----------------------------------------------------------------
62.5 M Trainable params
2 Non-trainable params
62.5 M Total params
250.001 Total estimated model params size (MB)
1420 Modules in train mode
0 Modules in eval mode
/mnt/big_partition/projet/JulienCombes/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/mnt/big_partition/projet/JulienCombes/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_steps=10` reached.
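As a follow-up, here is a hypothetical inference snippet (not executed above) that reuses predict and draw_bounding_boxes to eyeball the fine-tuned model's outputs on a validation batch; the 255 rescaling assumes the min_max normalization used in MyVoc:

model.eval()
with torch.no_grad():
    images, _ = next(iter(val_loader))
    preds = model.predict(images)  # per-image dicts: boxes (xyxy), scores, labels

boxed = draw_bounding_boxes(
    image=(images[0] * 255).to(torch.uint8),  # back to uint8 for drawing
    boxes=preds[0]["boxes"],
    labels=[categories[i] for i in preds[0]["labels"]],
    colors="green",
)
plt.imshow(boxed.permute(1, 2, 0).numpy())
plt.axis("off")
plt.show()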