Training API¶
API reference for training utilities.
Callbacks¶
Custom Lightning callbacks for object detection training.
ConfusionMatrixCallback¶
- class objdet.training.callbacks.confusion_matrix.ConfusionMatrixCallback(num_classes, iou_threshold=0.5, confidence_threshold=0.25, save_dir='outputs/confusion_matrices', class_names=None, normalize='true', save_format='png')[source]¶
Bases: Callback

Callback to compute and save a confusion matrix.

The confusion matrix shows true vs. predicted classes for all detections, matched to ground truth using an IoU threshold.
- Parameters:
  - num_classes (int) – Number of object classes (not including background).
  - iou_threshold (float) – IoU threshold for matching predictions to ground truth.
  - confidence_threshold (float) – Minimum confidence for predictions.
  - save_dir (str | Path) – Directory to save confusion matrix plots.
  - class_names (list[str] | None) – Optional list of class names for axis labels.
  - normalize (str | None) – How to normalize: "true", "pred", "all", or None.
  - save_format (str) – File format for saving ("png", "pdf", "svg").
Example
>>> callback = ConfusionMatrixCallback(
...     num_classes=80,
...     save_dir="outputs/confusion_matrices",
...     class_names=["person", "car", ...],
... )
>>> trainer = Trainer(callbacks=[callback])
- on_validation_epoch_start(trainer, pl_module)[source]¶
Reset confusion matrix at start of validation.
- Return type:
  None
Generates and saves confusion matrix visualizations during validation.
from objdet.training import ConfusionMatrixCallback
from lightning import Trainer
trainer = Trainer(
callbacks=[
ConfusionMatrixCallback(
num_classes=80,
class_names=class_names,
save_dir="./confusion_matrices",
),
],
)
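The IoU matching that fills the confusion matrix can be sketched in plain Python. The `box_iou` helper and greedy matching loop below are illustrative stand-ins for the callback's internals, not its actual implementation:

```python
def box_iou(a, b):
    """IoU of two boxes in (x1, y1, x2, y2) format."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def match_predictions(preds, gts, iou_threshold=0.5):
    """Greedily match each prediction to the best unmatched ground-truth box.

    preds and gts are lists of (box, class_id). Returns (pred_class, gt_class)
    pairs; a prediction with no match above the threshold pairs with None,
    i.e. it lands in the background column of the confusion matrix.
    """
    matched_gt = set()
    pairs = []
    for pbox, pcls in preds:
        best_iou, best_j = 0.0, None
        for j, (gbox, gcls) in enumerate(gts):
            if j in matched_gt:
                continue
            iou = box_iou(pbox, gbox)
            if iou >= iou_threshold and iou > best_iou:
                best_iou, best_j = iou, j
        if best_j is None:
            pairs.append((pcls, None))  # false positive -> background
        else:
            matched_gt.add(best_j)
            pairs.append((pcls, gts[best_j][1]))
    return pairs
```

Each matched pair increments one cell of the matrix; unmatched ground truths end up in the background row the same way.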
DetectionVisualizationCallback¶
- class objdet.training.callbacks.visualization.DetectionVisualizationCallback(num_samples=8, save_dir='outputs/visualizations', log_to_tensorboard=True, confidence_threshold=0.5, class_names=None, box_color=(0, 255, 0))[source]¶
Bases: Callback

Callback to visualize detection predictions on sample images.
- Parameters:
  - num_samples (int) – Number of samples to visualize per epoch.
  - save_dir (str | Path) – Directory to save visualization images.
  - log_to_tensorboard (bool) – Whether to also log to TensorBoard.
  - confidence_threshold (float) – Minimum confidence for visualization.
  - class_names (list[str] | None) – Optional list of class names for labels.
  - box_color (tuple[int, int, int] | str) – Color for prediction boxes (BGR tuple or "random").
Example
>>> callback = DetectionVisualizationCallback(
...     num_samples=8,
...     save_dir="outputs/visualizations",
...     class_names=["person", "car", "dog"],
... )
Visualizes detection predictions on sample images during training.
from objdet.training import DetectionVisualizationCallback
callback = DetectionVisualizationCallback(
num_samples=8,
confidence_threshold=0.5,
class_names=class_names,
)
GradientMonitorCallback¶
- class objdet.training.callbacks.gradient_monitor.GradientMonitorCallback(log_every_n_steps=100, detect_anomalies=False)[source]¶
Bases: Callback

Callback to monitor gradient statistics.
Logs gradient norms, min/max values, and NaN/Inf detection for debugging training stability issues.
- Parameters:
  - log_every_n_steps (int) – How often (in training steps) to log gradient statistics.
  - detect_anomalies (bool) – Whether to check gradients for NaN/Inf values.
Example
>>> callback = GradientMonitorCallback(log_every_n_steps=50)
>>> trainer = Trainer(callbacks=[callback])
Monitors and logs gradient statistics during training.
from objdet.training import GradientMonitorCallback
callback = GradientMonitorCallback(
log_every_n_steps=100,
)
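The statistics the callback logs (gradient norms, min/max values, NaN/Inf detection) can be sketched over a flat list of gradient values. This is an illustrative stand-in; the real callback reads per-parameter `.grad` tensors, and the key names below are hypothetical:

```python
import math


def gradient_stats(grads):
    """Summarize a flat list of gradient values: global L2 norm,
    min/max, and a count of non-finite (NaN/Inf) entries.
    Illustrative sketch of what a gradient monitor might log."""
    finite = [g for g in grads if math.isfinite(g)]
    return {
        "grad_norm": math.sqrt(sum(g * g for g in finite)),
        "grad_min": min(finite) if finite else float("nan"),
        "grad_max": max(finite) if finite else float("nan"),
        "num_nonfinite": len(grads) - len(finite),
    }
```

A non-zero `num_nonfinite` count is the typical signal of the training-stability issues mentioned above.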
LearningRateMonitorCallback¶
- class objdet.training.callbacks.lr_monitor.LearningRateMonitorCallback(log_momentum=False, log_weight_decay=False)[source]¶
Bases: Callback

Callback to monitor and log learning rates.
This extends Lightning’s built-in LearningRateMonitor with additional logging for multiple parameter groups.
- Parameters:
  - log_momentum (bool) – Whether to also log momentum values.
  - log_weight_decay (bool) – Whether to also log weight decay values.
Example
>>> callback = LearningRateMonitorCallback(log_momentum=True)
>>> trainer = Trainer(callbacks=[callback])
Enhanced learning rate monitoring with additional logging.
from objdet.training import LearningRateMonitorCallback
callback = LearningRateMonitorCallback()
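What "additional logging for multiple parameter groups" amounts to can be sketched against the `param_groups` structure that torch optimizers expose. The dicts below are a stand-in for real parameter groups, and the metric names are hypothetical, not the callback's actual keys:

```python
def lr_metrics(param_groups, log_momentum=False):
    """Build one {metric_name: value} entry per parameter group,
    mirroring the list-of-dicts shape of optimizer.param_groups.
    Names like "lr/backbone" are illustrative only."""
    metrics = {}
    for i, group in enumerate(param_groups):
        name = group.get("name", f"group{i}")
        metrics[f"lr/{name}"] = group["lr"]
        if log_momentum and "momentum" in group:
            metrics[f"momentum/{name}"] = group["momentum"]
    return metrics
```

Per-group logging matters for detectors that train the backbone and detection head at different learning rates.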
Metrics¶
Custom metrics for object detection evaluation.
ClasswiseAP¶
- class objdet.training.metrics.classwise_ap.ClasswiseAP(num_classes, iou_thresholds=None, class_names=None, dist_sync_on_step=False)[source]¶
Bases: Metric

Compute class-wise Average Precision.
Wraps torchmetrics MeanAveragePrecision to provide easy access to per-class AP values.
- Parameters:
  - num_classes (int) – Number of object classes.
  - iou_thresholds (list[float] | None) – IoU thresholds for AP computation; None uses the torchmetrics default.
  - class_names (list[str] | None) – Optional list of class names for result keys.
  - dist_sync_on_step (bool) – Whether to synchronize metric state across processes at each step.
Example
>>> metric = ClasswiseAP(num_classes=80, class_names=COCO_CLASSES)
>>> metric.update(predictions, targets)
>>> ap_per_class = metric.compute()
Computes per-class Average Precision.
from objdet.training.metrics import ClasswiseAP
metric = ClasswiseAP(
num_classes=80,
class_names=class_names,
iou_thresholds=[0.5],
)
# Update with predictions and targets
metric.update(predictions, targets)
# Compute results
results = metric.compute()
# {"class_0_AP": 0.85, "class_1_AP": 0.72, ...}
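A common use of the per-class results dict shown above is finding the weakest classes to guide data collection or augmentation. The helper below is a hypothetical convenience, not part of the library:

```python
def weakest_classes(ap_results, k=3):
    """Return the k (key, AP) pairs with the lowest AP, given a dict
    shaped like the one ClasswiseAP.compute() is shown returning."""
    return sorted(ap_results.items(), key=lambda kv: kv[1])[:k]


results = {"class_0_AP": 0.85, "class_1_AP": 0.72, "class_2_AP": 0.31}
print(weakest_classes(results, k=2))
# [('class_2_AP', 0.31), ('class_1_AP', 0.72)]
```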
ConfusionMatrix¶
- class objdet.training.metrics.confusion_matrix.ConfusionMatrix(num_classes, iou_threshold=0.5, confidence_threshold=0.25, dist_sync_on_step=False)[source]¶
Bases: Metric

Confusion matrix metric for object detection.
Unlike classification confusion matrices, detection requires matching predictions to ground truth using IoU threshold.
- Parameters:
  - num_classes (int) – Number of object classes (not including background).
  - iou_threshold (float) – IoU threshold for matching predictions to ground truth.
  - confidence_threshold (float) – Minimum confidence for predictions.
  - dist_sync_on_step (bool) – Whether to synchronize metric state across processes at each step.
Example
>>> metric = ConfusionMatrix(num_classes=80)
>>> metric.update(predictions, targets)
>>> cm = metric.compute()
Computes detection confusion matrix.
from objdet.training.metrics import ConfusionMatrix
metric = ConfusionMatrix(
num_classes=80,
iou_threshold=0.5,
confidence_threshold=0.25,
)
metric.update(predictions, targets)
matrix = metric.compute() # Returns (num_classes+1, num_classes+1) tensor
Note
The matrix includes an extra row/column for background (false positives/negatives).
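Per-class precision and recall can be derived from the returned matrix. The sketch below assumes rows are ground truth, columns are predictions, and the last row/column holds the background (unmatched) counts; those layout details are assumptions to verify against the actual implementation:

```python
def per_class_precision_recall(cm):
    """Derive (precision, recall) per class from a (C+1, C+1) detection
    confusion matrix given as nested lists. Assumes rows = ground truth,
    columns = predictions, last index = background."""
    num_classes = len(cm) - 1
    stats = []
    for c in range(num_classes):
        tp = cm[c][c]
        pred_total = sum(cm[r][c] for r in range(len(cm)))   # all predicted as c
        gt_total = sum(cm[c][col] for col in range(len(cm)))  # all ground truth of c
        precision = tp / pred_total if pred_total else 0.0
        recall = tp / gt_total if gt_total else 0.0
        stats.append((precision, recall))
    return stats
```

Here the background column contributes missed detections to `gt_total` (hurting recall) and the background row contributes false positives to `pred_total` (hurting precision).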