Training API

API reference for training utilities.

Callbacks

Custom Lightning callbacks for object detection training.

ConfusionMatrixCallback

class objdet.training.callbacks.confusion_matrix.ConfusionMatrixCallback(num_classes, iou_threshold=0.5, confidence_threshold=0.25, save_dir='outputs/confusion_matrices', class_names=None, normalize='true', save_format='png')[source]

Bases: Callback

Callback to compute and save confusion matrix.

The confusion matrix shows true versus predicted classes for all detections matched to ground truth using an IoU threshold.

Parameters:
  • num_classes (int) – Number of object classes (not including background).

  • iou_threshold (float) – IoU threshold for matching predictions to ground truth.

  • confidence_threshold (float) – Minimum confidence for predictions.

  • save_dir (str | Path) – Directory to save confusion matrix plots.

  • class_names (list[str] | None) – Optional list of class names for axis labels.

  • normalize (str | None) – How to normalize - “true”, “pred”, “all”, or None.

  • save_format (str) – File format for saving (“png”, “pdf”, “svg”).
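
The normalize modes can be illustrated with a small, framework-free sketch. It assumes the scikit-learn convention (rows index true classes, columns index predicted classes), which this page does not confirm for ConfusionMatrixCallback. The helper name normalize_matrix is hypothetical.

```python
def normalize_matrix(matrix, mode):
    """Normalize a square count matrix given as a list of rows.

    Assumed layout: rows = true class, columns = predicted class.
    mode "true" makes each row sum to 1, "pred" makes each column sum
    to 1, "all" makes the whole matrix sum to 1, None returns counts.
    """
    if mode is None:
        return [list(row) for row in matrix]
    if mode == "true":
        row_sums = [sum(row) for row in matrix]
        return [[v / s if s else 0.0 for v in row]
                for row, s in zip(matrix, row_sums)]
    if mode == "pred":
        col_sums = [sum(col) for col in zip(*matrix)]
        return [[v / s if s else 0.0 for v, s in zip(row, col_sums)]
                for row in matrix]
    if mode == "all":
        total = sum(sum(row) for row in matrix)
        return [[v / total if total else 0.0 for v in row] for row in matrix]
    raise ValueError(f"unknown normalize mode: {mode!r}")


counts = [[8, 2], [1, 9]]
by_true = normalize_matrix(counts, "true")  # rows sum to 1
by_pred = normalize_matrix(counts, "pred")  # columns sum to 1
by_all = normalize_matrix(counts, "all")    # all entries sum to 1
```

Under this convention, "true" puts per-class recall on the diagonal and "pred" puts per-class precision there.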

Example

>>> callback = ConfusionMatrixCallback(
...     num_classes=80,
...     save_dir="outputs/confusion_matrices",
...     class_names=["person", "car", ...],
... )
>>> trainer = Trainer(callbacks=[callback])

on_validation_epoch_start(trainer, pl_module)[source]

Reset confusion matrix at start of validation.

Return type:

None

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Update confusion matrix with batch predictions.

Return type:

None

on_validation_epoch_end(trainer, pl_module)[source]

Save confusion matrix at end of validation.

Return type:

None

Generates and saves confusion matrix visualizations during validation.

from objdet.training import ConfusionMatrixCallback
from lightning import Trainer

trainer = Trainer(
    callbacks=[
        ConfusionMatrixCallback(
            num_classes=80,
            class_names=class_names,
            save_dir="./confusion_matrices",
        ),
    ],
)

DetectionVisualizationCallback

class objdet.training.callbacks.visualization.DetectionVisualizationCallback(num_samples=8, save_dir='outputs/visualizations', log_to_tensorboard=True, confidence_threshold=0.5, class_names=None, box_color=(0, 255, 0))[source]

Bases: Callback

Callback to visualize detection predictions on sample images.

Parameters:
  • num_samples (int) – Number of samples to visualize per epoch.

  • save_dir (str | Path) – Directory to save visualization images.

  • log_to_tensorboard (bool) – Whether to also log to TensorBoard.

  • confidence_threshold (float) – Minimum confidence for visualization.

  • class_names (list[str] | None) – Optional list of class names for labels.

  • box_color (tuple[int, int, int] | str) – Color for prediction boxes (BGR tuple or “random”).

Example

>>> callback = DetectionVisualizationCallback(
...     num_samples=8,
...     save_dir="outputs/visualizations",
...     class_names=["person", "car", "dog"],
... )

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Collect samples for visualization.

Return type:

None

on_validation_epoch_end(trainer, pl_module)[source]

Create and save visualizations.

Return type:

None

Visualizes detection predictions on sample images during validation.

from objdet.training import DetectionVisualizationCallback

callback = DetectionVisualizationCallback(
    num_samples=8,
    confidence_threshold=0.5,
    class_names=class_names,
)

GradientMonitorCallback

class objdet.training.callbacks.gradient_monitor.GradientMonitorCallback(log_every_n_steps=100, detect_anomalies=False)[source]

Bases: Callback

Callback to monitor gradient statistics.

Logs gradient norms and min/max values, and flags NaN/Inf gradients, to help debug training stability issues.

Parameters:
  • log_every_n_steps (int) – How often to log gradient stats.

  • detect_anomalies (bool) – Whether to enable anomaly detection.
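
As a rough illustration of the statistics named above, the following framework-free sketch summarizes a flat list of gradient values. It is not the callback's implementation. In a real run the values would come from each parameter's .grad tensor, and the helper name gradient_stats and its output keys are hypothetical.

```python
import math


def gradient_stats(grads):
    """Summarize a flat iterable of gradient values.

    Non-finite entries (NaN/Inf) are counted separately and excluded
    from the norm and min/max so one bad value does not hide the rest.
    """
    values = list(grads)
    finite = [g for g in values if math.isfinite(g)]
    return {
        "l2_norm": math.sqrt(sum(g * g for g in finite)),
        "min": min(finite) if finite else None,
        "max": max(finite) if finite else None,
        "num_nan": sum(1 for g in values if math.isnan(g)),
        "num_inf": sum(1 for g in values if math.isinf(g)),
    }


stats = gradient_stats([3.0, -4.0, float("nan"), float("inf")])
```

A nonzero num_nan or num_inf is usually the first sign of divergence, often before the loss itself becomes NaN.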

Example

>>> callback = GradientMonitorCallback(log_every_n_steps=50)
>>> trainer = Trainer(callbacks=[callback])

on_before_optimizer_step(trainer, pl_module, optimizer)[source]

Monitor gradients before optimizer step.

Return type:

None

Monitors and logs gradient statistics during training.

from objdet.training import GradientMonitorCallback

callback = GradientMonitorCallback(
    log_every_n_steps=100,
)

LearningRateMonitorCallback

class objdet.training.callbacks.lr_monitor.LearningRateMonitorCallback(log_momentum=False, log_weight_decay=False)[source]

Bases: Callback

Callback to monitor and log learning rates.

This extends Lightning’s built-in LearningRateMonitor with additional logging for multiple parameter groups.

Parameters:
  • log_momentum (bool) – Whether to also log momentum values.

  • log_weight_decay (bool) – Whether to log weight decay values.
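
To show what per-group logging can look like, here is a minimal sketch that walks a list shaped like torch.optim.Optimizer.param_groups (a list of dicts with an "lr" key). The helper name collect_group_stats and the key format "lr/group_0" are assumptions for illustration, not the names this callback logs under.

```python
def collect_group_stats(param_groups, log_momentum=False, log_weight_decay=False):
    """Build a flat {name: value} dict from optimizer parameter groups."""
    stats = {}
    for i, group in enumerate(param_groups):
        stats[f"lr/group_{i}"] = group["lr"]
        if log_momentum:
            # SGD-style groups store "momentum"; Adam-style groups store "betas".
            if "momentum" in group:
                stats[f"momentum/group_{i}"] = group["momentum"]
            elif "betas" in group:
                stats[f"momentum/group_{i}"] = group["betas"][0]
        if log_weight_decay and "weight_decay" in group:
            stats[f"weight_decay/group_{i}"] = group["weight_decay"]
    return stats


groups = [
    {"lr": 0.001, "momentum": 0.9, "weight_decay": 0.0001},  # e.g. backbone
    {"lr": 0.01, "momentum": 0.9, "weight_decay": 0.0},      # e.g. detection head
]
stats = collect_group_stats(groups, log_momentum=True)
```

Separate entries per group matter for detectors that train the backbone and the head at different learning rates.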

Example

>>> callback = LearningRateMonitorCallback(log_momentum=True)
>>> trainer = Trainer(callbacks=[callback])

on_train_batch_start(trainer, pl_module, batch, batch_idx)[source]

Log learning rates at batch start.

Return type:

None

Enhanced learning rate monitoring with additional logging.

from objdet.training import LearningRateMonitorCallback

callback = LearningRateMonitorCallback()

Metrics

Custom metrics for object detection evaluation.

ClasswiseAP

class objdet.training.metrics.classwise_ap.ClasswiseAP(num_classes, iou_thresholds=None, class_names=None, dist_sync_on_step=False)[source]

Bases: Metric

Compute class-wise Average Precision.

Wraps torchmetrics MeanAveragePrecision to provide easy access to per-class AP values.

Parameters:
  • num_classes (int) – Number of object classes.

  • iou_thresholds (list[float] | None) – List of IoU thresholds.

  • class_names (list[str] | None) – Optional list of class names.
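
For context, torchmetrics' MeanAveragePrecision, when constructed with class_metrics=True, reports per-class results as two parallel tensors, classes and map_per_class. The sketch below shows one way such a pair could be turned into a name-keyed dictionary. It uses plain lists in place of tensors, the helper name name_per_class_ap is hypothetical, and the keys ClasswiseAP itself returns are not specified on this page.

```python
def name_per_class_ap(classes, ap_values, class_names=None):
    """Pair class ids with AP values, using names when available."""
    result = {}
    for class_id, ap in zip(classes, ap_values):
        if class_names is not None and 0 <= class_id < len(class_names):
            key = class_names[class_id]
        else:
            # Fall back to a generic key when no name is known.
            key = f"class_{class_id}"
        result[key] = float(ap)
    return result


# Classes absent from the evaluated data do not appear in `classes`.
per_class = name_per_class_ap(
    classes=[0, 2],
    ap_values=[0.85, 0.40],
    class_names=["person", "car", "dog"],
)
```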

Example

>>> metric = ClasswiseAP(num_classes=80, class_names=COCO_CLASSES)
>>> metric.update(predictions, targets)
>>> ap_per_class = metric.compute()

is_differentiable: bool | None = False
higher_is_better: bool | None = True
full_state_update: bool | None = True

update(preds, targets)[source]

Update metric with batch of predictions.

Return type:

None

compute()[source]

Compute class-wise AP.

Return type:

dict[str, Tensor | dict[str, float]]

Returns:

Dictionary with overall metrics and per-class AP.

reset()[source]

Reset underlying metric.

Return type:

None

Computes per-class Average Precision.

from objdet.training.metrics import ClasswiseAP

metric = ClasswiseAP(
    num_classes=80,
    class_names=class_names,
    iou_thresholds=[0.5],
)

# Update with predictions and targets
metric.update(predictions, targets)

# Compute results
results = metric.compute()
# A dictionary with overall metrics and per-class AP values

ConfusionMatrix

class objdet.training.metrics.confusion_matrix.ConfusionMatrix(num_classes, iou_threshold=0.5, confidence_threshold=0.25, dist_sync_on_step=False)[source]

Bases: Metric

Confusion matrix metric for object detection.

Unlike classification confusion matrices, a detection confusion matrix requires first matching predictions to ground truth boxes using an IoU threshold.

Parameters:
  • num_classes (int) – Number of object classes.

  • iou_threshold (float) – IoU threshold for matching.

  • confidence_threshold (float) – Minimum prediction confidence.

  • dist_sync_on_step (bool) – Whether to sync on step (for DDP).
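
The matching step can be sketched as follows. This is one common greedy strategy (highest-scoring prediction first, each ground truth box used at most once), shown under the assumption of (x1, y1, x2, y2) boxes. The page does not state which matching strategy ConfusionMatrix uses, and the function names here are hypothetical.

```python
def box_iou(a, b):
    """IoU of two boxes in (x1, y1, x2, y2) format."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def match_detections(preds, gts, iou_threshold=0.5, confidence_threshold=0.25):
    """Greedily match predictions to ground truth boxes.

    preds: list of (box, score, label); gts: list of (box, label).
    Returns (pairs, unmatched_pred_indices, unmatched_gt_indices),
    where pairs holds (pred_index, gt_index) tuples.
    """
    # Drop low-confidence predictions, then process the rest by score.
    keep = [i for i, p in enumerate(preds) if p[1] >= confidence_threshold]
    keep.sort(key=lambda i: preds[i][1], reverse=True)
    used_gt = set()
    pairs, unmatched_preds = [], []
    for i in keep:
        best_j, best_iou = None, iou_threshold
        for j, (gt_box, _) in enumerate(gts):
            if j in used_gt:
                continue
            iou = box_iou(preds[i][0], gt_box)
            if iou >= best_iou:
                best_j, best_iou = j, iou
        if best_j is None:
            unmatched_preds.append(i)
        else:
            used_gt.add(best_j)
            pairs.append((i, best_j))
    unmatched_gts = [j for j in range(len(gts)) if j not in used_gt]
    return pairs, unmatched_preds, unmatched_gts


preds = [((0, 0, 10, 10), 0.9, 1), ((50, 50, 60, 60), 0.8, 0), ((0, 0, 10, 10), 0.1, 1)]
gts = [((0, 0, 10, 9), 1), ((100, 100, 110, 110), 2)]
pairs, extra_preds, missed_gts = match_detections(preds, gts)
```

Matching here ignores class labels on purpose: a box that overlaps a ground truth object but carries the wrong label is exactly what the off-diagonal cells record.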

Example

>>> metric = ConfusionMatrix(num_classes=80)
>>> metric.update(predictions, targets)
>>> cm = metric.compute()

is_differentiable: bool | None = False
higher_is_better: bool | None = None
full_state_update: bool | None = False
matrix: Tensor

update(preds, targets)[source]

Update confusion matrix with batch of predictions.

Return type:

None

compute()[source]

Compute confusion matrix.

Return type:

Tensor

Returns:

Confusion matrix tensor of shape (num_classes+1, num_classes+1).

Computes detection confusion matrix.

from objdet.training.metrics import ConfusionMatrix

metric = ConfusionMatrix(
    num_classes=80,
    iou_threshold=0.5,
    confidence_threshold=0.25,
)

metric.update(predictions, targets)
matrix = metric.compute()  # Returns (num_classes+1, num_classes+1) tensor

Note

The matrix includes an extra row/column for background (false positives/negatives).
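
A minimal sketch of how the extra background row and column can be filled, given detections that have already been matched. The layout is an assumption: rows index the true class, columns index the predicted class, and index num_classes is reserved for background. The page does not specify the orientation, and the helper name build_confusion_matrix is hypothetical.

```python
def build_confusion_matrix(num_classes, matched, unmatched_pred_labels, unmatched_gt_labels):
    """Fill a (num_classes+1) x (num_classes+1) count matrix.

    matched: list of (true_label, pred_label) for IoU-matched pairs.
    unmatched_pred_labels: labels of predictions with no ground truth match.
    unmatched_gt_labels: labels of ground truth boxes that were never matched.
    """
    background = num_classes
    size = num_classes + 1
    matrix = [[0] * size for _ in range(size)]
    for true_label, pred_label in matched:
        matrix[true_label][pred_label] += 1
    for pred_label in unmatched_pred_labels:
        # False positive: background predicted as an object.
        matrix[background][pred_label] += 1
    for true_label in unmatched_gt_labels:
        # False negative: an object missed, i.e. predicted as background.
        matrix[true_label][background] += 1
    return matrix


cm = build_confusion_matrix(
    num_classes=2,
    matched=[(0, 0), (0, 1), (1, 1)],
    unmatched_pred_labels=[1],
    unmatched_gt_labels=[0],
)
```

With this layout the background/background cell stays at zero, since a detector never reports a "true background" match.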