Source code for objdet.training.callbacks.gradient_monitor

"""Gradient monitoring callback.

This callback monitors gradient statistics during training,
useful for debugging training issues.
"""

from __future__ import annotations

import lightning as L
import torch
from lightning.pytorch.callbacks import Callback
from torch.optim import Optimizer

from objdet.core.logging import get_logger

logger = get_logger(__name__)


class GradientMonitorCallback(Callback):
    """Callback to monitor gradient statistics.

    Logs gradient norms, maximum absolute gradient values, and NaN/Inf
    flags for debugging training stability issues.

    Args:
        log_every_n_steps: Interval (in optimizer steps) between gradient-stat logs.
        detect_anomalies: Whether to emit a warning when NaN/Inf gradients are detected.

    Example:
        >>> callback = GradientMonitorCallback(log_every_n_steps=50)
        >>> trainer = Trainer(callbacks=[callback])
    """

    def __init__(
        self,
        log_every_n_steps: int = 100,
        detect_anomalies: bool = False,
    ) -> None:
        super().__init__()
        self.log_every_n_steps = log_every_n_steps
        self.detect_anomalies = detect_anomalies
    def on_before_optimizer_step(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        optimizer: Optimizer,
    ) -> None:
        """Monitor gradients before each optimizer step."""
        # Respect the configured logging interval.
        if trainer.global_step % self.log_every_n_steps != 0:
            return

        grad_norms, grad_max, has_nan, has_inf = self._process_gradients(pl_module)

        if grad_norms:
            # Total norm over all parameters is the L2 norm of the
            # per-parameter L2 norms.
            total_norm = torch.tensor(grad_norms).norm(2).item()
            max_grad = max(grad_max)
            mean_norm = sum(grad_norms) / len(grad_norms)

            # Log statistics
            pl_module.log("train/grad_norm", total_norm, prog_bar=False)
            pl_module.log("train/grad_max", max_grad, prog_bar=False)
            pl_module.log("train/grad_mean_norm", mean_norm, prog_bar=False)

            # NaN/Inf flags are only logged when triggered, so their
            # presence in the logs signals a problem.
            if has_nan:
                pl_module.log("train/grad_has_nan", 1.0, prog_bar=False)
            if has_inf:
                pl_module.log("train/grad_has_inf", 1.0, prog_bar=False)
    def _process_gradients(
        self, pl_module: L.LightningModule
    ) -> tuple[list[float], list[float], bool, bool]:
        """Process gradients and collect per-parameter statistics."""
        grad_norms: list[float] = []
        grad_max: list[float] = []
        has_nan = False
        has_inf = False

        for name, param in pl_module.named_parameters():
            if param.grad is None:
                continue

            grad = param.grad.detach()
            is_nan, is_inf = self._check_anomalies(grad, name)
            has_nan = has_nan or is_nan
            has_inf = has_inf or is_inf

            # Collect statistics
            grad_norms.append(grad.norm(2).item())
            grad_max.append(grad.abs().max().item())

        return grad_norms, grad_max, has_nan, has_inf

    def _check_anomalies(self, grad: torch.Tensor, name: str) -> tuple[bool, bool]:
        """Check a gradient tensor for NaN and Inf values."""
        is_nan = bool(grad.isnan().any())
        is_inf = bool(grad.isinf().any())

        if self.detect_anomalies:
            if is_nan:
                logger.warning(f"NaN gradient detected in {name}")
            if is_inf:
                logger.warning(f"Inf gradient detected in {name}")

        return is_nan, is_inf
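

if __name__ == "__main__":  # pragma: no cover
    # Minimal usage sketch: attach the callback to a Lightning Trainer.
    # `MyDetector` and `train_loader` are hypothetical placeholders for
    # whatever LightningModule and DataLoader the project defines; they
    # are not part of this module.
    callback = GradientMonitorCallback(log_every_n_steps=50, detect_anomalies=True)
    trainer = L.Trainer(max_epochs=1, callbacks=[callback], logger=False)
    # trainer.fit(MyDetector(), train_loader)  # hypothetical model/dataloader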