
Conversation

@theo-barfoot

Description

Addresses #8505

Overview

This PR adds calibration error metrics and an Ignite handler to MONAI, enabling users to evaluate and monitor model calibration for segmentation and other multi-class probabilistic tasks with inputs of shape (B, C, spatial...).

What's Included

1. Calibration Metrics (monai/metrics/calibration.py)

  • calibration_binning(): Core function to compute calibration bins with mean predictions, mean ground truths, and bin counts. Exported to support research workflows where users need per-bin statistics for plotting reliability diagrams.
  • CalibrationReduction: Enum supporting three reduction methods (see the sketch after this list):
    • EXPECTED - Expected Calibration Error (ECE): weighted average by bin count
    • AVERAGE - Average Calibration Error (ACE): simple average across bins
    • MAXIMUM - Maximum Calibration Error (MCE): maximum error across bins
  • CalibrationErrorMetric: A CumulativeIterationMetric subclass supporting:
    • Configurable number of bins
    • Background channel exclusion (include_background)
    • All standard MONAI metric reductions (mean, sum, mean_batch, etc.)
    • Batched, per-class computation
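
As a rough illustration of how the three reductions relate to the per-bin statistics above, consider the sketch below; the variable names (mean_preds, mean_gts, bin_counts) are illustrative placeholders rather than the exact return signature of calibration_binning.

```python
import torch

# hypothetical per-bin statistics for one class; an empty bin yields NaN means
mean_preds = torch.tensor([0.10, 0.30, float("nan"), 0.90])   # mean predicted probability per bin
mean_gts = torch.tensor([0.05, 0.35, float("nan"), 0.80])     # mean ground-truth frequency per bin
bin_counts = torch.tensor([50.0, 30.0, 0.0, 20.0])            # number of elements per bin

abs_diff = (mean_preds - mean_gts).abs()                       # per-bin calibration gap

ece = torch.nansum(abs_diff * bin_counts) / bin_counts.sum()   # EXPECTED: count-weighted average
ace = torch.nanmean(abs_diff)                                  # AVERAGE: simple mean over non-empty bins
mce = torch.max(torch.nan_to_num(abs_diff, nan=0.0))           # MAXIMUM: largest per-bin gap
```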

2. Ignite Handler (monai/handlers/calibration.py)

  • CalibrationError: An IgniteMetricHandler wrapper that:
    • Attaches to PyTorch Ignite engines for training/validation loops
    • Supports save_details for per-sample/per-channel metric details via the metric buffer
    • Integrates with MONAI's existing handler ecosystem

3. Comprehensive Tests

  • tests/metrics/test_calibration_metric.py: Tests covering:
    • Binning function correctness with NaN handling
    • ECE/ACE/MCE reduction modes
    • Background exclusion
    • Cumulative iteration behavior (see the sketch after this list)
    • Input validation (shape mismatch, ndim, num_bins)
  • tests/handlers/test_handler_calibration_error.py: Tests covering:
    • Handler attachment and computation via engine.run()
    • All calibration reduction modes
    • save_details functionality
    • Optional Ignite dependency handling (tests skip if Ignite not installed)
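
As a rough illustration of the cumulative iteration behavior these tests exercise, here is a minimal sketch; shapes and values are arbitrary, and the constructor arguments simply follow the description above.

```python
import torch

from monai.metrics import CalibrationErrorMetric

metric = CalibrationErrorMetric(num_bins=10, include_background=True)

for _ in range(3):                               # three validation iterations
    y_pred = torch.rand(2, 2, 8, 8)              # probabilities in [0, 1], shape (B, C, H, W)
    y = (torch.rand(2, 2, 8, 8) > 0.5).float()   # binarized targets with matching shape
    metric(y_pred=y_pred, y=y)                   # per-iteration results accumulate in the buffer

result = metric.aggregate()                      # aggregates over all accumulated iterations
metric.reset()                                   # clears the buffer for the next evaluation run
```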

Public API

Exposes the following via monai.metrics:

  • CalibrationErrorMetric
  • CalibrationReduction
  • calibration_binning

Exposes via monai.handlers:

  • CalibrationError

Implementation Notes

  • Uses scatter_add + counts instead of scatter_reduce("mean") for better PyTorch version compatibility (sketched below)
  • Includes input validation with clear error messages
  • Clamps bin indices to prevent out-of-range errors with slightly out-of-bound probabilities
  • Uses torch.nan_to_num instead of in-place operations for cleaner code
  • Ignite is treated as an optional dependency in tests (skipped if not installed)
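
A simplified, stand-alone sketch of the scatter_add-plus-counts pattern and the bin-index clamping described above (illustrative only, not the actual MONAI implementation):

```python
import torch

num_bins = 5
probs = torch.tensor([0.02, 0.18, 0.21, 0.77, 1.00])   # flattened predicted probabilities
targets = torch.tensor([0.0, 0.0, 1.0, 1.0, 1.0])      # matching binary ground truth

# clamp so a probability of exactly 1.0 (or slightly above) stays in the last bin
bin_idx = torch.clamp((probs * num_bins).long(), max=num_bins - 1)

counts = torch.zeros(num_bins).scatter_add_(0, bin_idx, torch.ones_like(probs))
sum_preds = torch.zeros(num_bins).scatter_add_(0, bin_idx, probs)
sum_gts = torch.zeros(num_bins).scatter_add_(0, bin_idx, targets)

mean_preds = sum_preds / counts   # empty bins become NaN (0 / 0) rather than raising an error
mean_gts = sum_gts / counts
```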

Related Work

The algorithmic approach follows the calibration metrics from the Average-Calibration-Losses repository and its related publications.

Future Work

As discussed in the issue, calibration losses will be added in a separate PR to keep changes focused and easier to review.

Checklist

  • Code follows MONAI style guidelines (ruff passes)
  • All new code has appropriate license headers
  • Public API is exported in __init__.py files
  • Docstrings include examples with proper transforms usage
  • Unit tests cover main functionality
  • Tests handle optional Ignite dependency gracefully
  • No breaking changes to existing API

Example Usage

from monai.metrics import CalibrationErrorMetric
from monai.transforms import Activations, AsDiscrete

# Setup transforms (num_classes is the number of classes, defined elsewhere by the user)
softmax = Activations(softmax=True)
to_onehot = AsDiscrete(to_onehot=num_classes)

# Create metric
metric = CalibrationErrorMetric(
    num_bins=15,
    include_background=False,
    calibration_reduction="expected"  # ECE
)

# In evaluation loop
# Note: y_pred should be probabilities in [0,1], y should be one-hot/binarized
for batch_data in dataloader:
    images, labels = batch_data["image"], batch_data["label"]
    logits = model(images)
    preds = softmax(logits)
    labels_onehot = to_onehot(labels)
    metric(y_pred=preds, y=labels_onehot)

ece = metric.aggregate()

With Ignite Handler

from monai.handlers import CalibrationError, from_engine

calibration_handler = CalibrationError(
    num_bins=15,
    include_background=False,
    calibration_reduction="expected",
    output_transform=from_engine(["pred", "label"]),
    save_details=True,
)
calibration_handler.attach(evaluator, name="calibration_error")

- Add calibration_binning() function for hard binning calibration
- Add CalibrationErrorMetric with ECE/ACE/MCE reduction modes
- Add CalibrationError Ignite handler
- Add comprehensive tests for metrics and handler

Addresses Project-MONAI#8505

Signed-off-by: Theo Barfoot <[email protected]>
@coderabbitai
Contributor

coderabbitai bot commented Jan 16, 2026

📝 Walkthrough

This pull request introduces a new calibration error metric system for MONAI. A new metrics module (monai/metrics/calibration.py) implements the core functionality: a calibration_binning function that computes per-bin statistics, a CalibrationReduction enum for reduction strategies (EXPECTED, AVERAGE, MAXIMUM), and a CalibrationErrorMetric class that extends CumulativeIterationMetric. A corresponding handler module (monai/handlers/calibration.py) wraps this metric for PyTorch Ignite integration. Public API exports are added to both monai/metrics/__init__.py and monai/handlers/__init__.py. Comprehensive test coverage is provided for both the metric and handler implementations.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 40.91%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | Title clearly summarizes the main change: adding CalibrationErrorMetric and CalibrationError handler to MONAI. |
| Description check | ✅ Passed | Description comprehensively covers the changes, includes checklist items, examples, and context. Follows the template structure with a detailed overview of additions. |



@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@monai/metrics/calibration.py`:
- Around line 228-235: In the CalibrationReduction.MAXIMUM branch, don’t convert
NaN to 0 (which hides “no data”); instead use a -inf sentinel when calling
torch.nan_to_num on abs_diff (e.g. nan=-torch.inf), take the max along dim=-1,
then detect buckets that were all-NaN (e.g. all_nan_mask =
torch.isnan(abs_diff).all(dim=-1)) and restore those positions in the result to
NaN; update the method where self.calibration_reduction is checked (the MAXIMUM
branch that uses abs_diff_no_nan) accordingly and add a unit test covering the
“all bins empty” case to prevent regressions.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 57fdd59 and 202b25f.

📒 Files selected for processing (6)
  • monai/handlers/__init__.py
  • monai/handlers/calibration.py
  • monai/metrics/__init__.py
  • monai/metrics/calibration.py
  • tests/handlers/test_handler_calibration_error.py
  • tests/metrics/test_calibration_metric.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides and are sensible and informative in regard to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of Google-style docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/metrics/__init__.py
  • monai/handlers/__init__.py
  • monai/handlers/calibration.py
  • tests/handlers/test_handler_calibration_error.py
  • monai/metrics/calibration.py
  • tests/metrics/test_calibration_metric.py
🧬 Code graph analysis (6)
monai/metrics/__init__.py (1)
monai/metrics/calibration.py (3)
  • CalibrationErrorMetric (139-260)
  • CalibrationReduction (125-136)
  • calibration_binning (30-122)
monai/handlers/__init__.py (1)
monai/handlers/calibration.py (1)
  • CalibrationError (23-71)
monai/handlers/calibration.py (1)
monai/utils/enums.py (1)
  • MetricReduction (239-250)
tests/handlers/test_handler_calibration_error.py (4)
monai/handlers/calibration.py (1)
  • CalibrationError (23-71)
monai/handlers/utils.py (1)
  • from_engine (170-210)
monai/utils/module.py (2)
  • min_version (273-285)
  • optional_import (315-445)
tests/test_utils.py (1)
  • assert_allclose (119-159)
monai/metrics/calibration.py (4)
monai/metrics/metric.py (1)
  • CumulativeIterationMetric (296-353)
monai/metrics/utils.py (2)
  • do_metric_reduction (71-130)
  • ignore_background (54-68)
monai/utils/enums.py (2)
  • MetricReduction (239-250)
  • StrEnum (68-90)
monai/utils/profiling.py (1)
  • end (430-432)
tests/metrics/test_calibration_metric.py (3)
monai/metrics/calibration.py (4)
  • CalibrationErrorMetric (139-260)
  • CalibrationReduction (125-136)
  • calibration_binning (30-122)
  • aggregate (239-260)
monai/utils/enums.py (1)
  • MetricReduction (239-250)
monai/metrics/metric.py (1)
  • get_buffer (282-293)
🪛 Ruff (0.14.11)
tests/handlers/test_handler_calibration_error.py

106-106: Unused function argument: engine

(ARG001)


142-142: Unused function argument: engine

(ARG001)


168-168: Unused function argument: engine

(ARG001)

monai/metrics/calibration.py

23-27: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


71-71: Avoid specifying long messages outside the exception class

(TRY003)


73-73: Avoid specifying long messages outside the exception class

(TRY003)


75-75: Avoid specifying long messages outside the exception class

(TRY003)


204-204: Unused method argument: kwargs

(ARG002)


237-237: Avoid specifying long messages outside the exception class

(TRY003)


256-256: Prefer TypeError exception for invalid type

(TRY004)


256-256: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (ubuntu-latest)
🔇 Additional comments (14)
monai/handlers/__init__.py (1)

15-15: LGTM!

Import is correctly placed alphabetically and aligns with the new CalibrationError handler in monai/handlers/calibration.py.

monai/metrics/__init__.py (1)

16-16: LGTM!

Public API exports correctly added for the new calibration functionality.

tests/metrics/test_calibration_metric.py (4)

14-25: LGTM!

Imports and device setup are appropriate. Good use of _device for CUDA/CPU portability.


142-196: LGTM!

Comprehensive binning tests with proper NaN handling and input validation coverage.


199-253: LGTM!

All three calibration reduction modes tested with proper isolation from metric reduction.


256-354: LGTM!

Good coverage of metric options including include_background, metric_reduction, get_not_nans, cumulative behavior, and reset.

tests/handlers/test_handler_calibration_error.py (3)

19-25: LGTM!

Proper optional import pattern for Ignite with version check and skip decorator.


82-122: LGTM!

Handler tests properly verify metric computation and details shape. The unused engine parameter in _val_func is required by Ignite's callback signature.


124-181: LGTM!

Edge case tests cover single iteration and save_details=False behavior with appropriate defensive checks.

monai/handlers/calibration.py (1)

23-71: LGTM!

Clean handler implementation following MONAI patterns. Docstring adequately documents all parameters. Consider adding a usage example similar to other handlers if desired.

monai/metrics/calibration.py (4)

30-122: calibration_binning looks solid

Validation, binning, and empty-bin NaN handling are clear and consistent with the stated contract.


125-136: Enum values are clear

Naming and values match expected calibration reduction modes.


187-203: Init wiring looks good

Config is stored cleanly and defaults are sensible.


239-260: Aggregate logic is clean

Reduction and get_not_nans behavior are consistent with MONAI patterns.


Comment on lines +228 to +235
        if self.calibration_reduction == CalibrationReduction.EXPECTED:
            # Calculate the weighted sum of absolute differences
            return torch.nansum(abs_diff * bin_counts, dim=-1) / torch.sum(bin_counts, dim=-1)
        elif self.calibration_reduction == CalibrationReduction.AVERAGE:
            return torch.nanmean(abs_diff, dim=-1)  # Average across all dimensions, ignoring nan
        elif self.calibration_reduction == CalibrationReduction.MAXIMUM:
            abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=0.0)
            return torch.max(abs_diff_no_nan, dim=-1).values  # Maximum across all dimensions

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "calibration.py" | grep metrics



🏁 Script executed:

cat -n monai/metrics/calibration.py | head -250 | tail -50



🏁 Script executed:

rg "CalibrationReduction" --type py -B 2 -A 2



🏁 Script executed:

fd "test.*calibration" --type f



🏁 Script executed:

cat tests/metrics/test_calibration_metric.py



🏁 Script executed:

rg "calibration_binning" -B 5 -A 10 monai/metrics/calibration.py | head -80



🏁 Script executed:

python3 << 'EOF'
import torch

# Test torch.max behavior with all NaN tensor
all_nan = torch.tensor([[float('nan'), float('nan')], [float('nan'), float('nan')]])
print("All NaN tensor:", all_nan)
print("torch.max(all_nan, dim=-1):", torch.max(all_nan, dim=-1))

# Test nan_to_num with nan=0.0
nan_to_zero = torch.nan_to_num(all_nan, nan=0.0)
print("\nnan_to_num with nan=0.0:", nan_to_zero)
print("torch.max(nan_to_zero, dim=-1):", torch.max(nan_to_zero, dim=-1))

# Test nan_to_num with nan=-inf
nan_to_neginf = torch.nan_to_num(all_nan, nan=-float('inf'))
print("\nnan_to_num with nan=-inf:", nan_to_neginf)
print("torch.max(nan_to_neginf, dim=-1):", torch.max(nan_to_neginf, dim=-1))

# Test torch.nanmax
print("\ntorch.nanmax(all_nan, dim=-1):", torch.nanmax(all_nan, dim=-1))

# Test comparison for restoration
max_vals = torch.max(nan_to_neginf, dim=-1).values
print("\nmax_vals after max(nan_to_neginf):", max_vals)
print("max_vals == -inf:", max_vals == -float('inf'))
max_vals[max_vals == -float('inf')] = float('nan')
print("After restoring NaN:", max_vals)

EOF



🏁 Script executed:

python3 << 'EOF'
import torch

# Test 1: torch.max behavior with all NaN tensor
print("=== Test 1: All NaN tensor ===")
all_nan = torch.tensor([[float('nan'), float('nan')], [float('nan'), float('nan')]])
print("Input:", all_nan)
result = torch.max(all_nan, dim=-1)
print("torch.max(all_nan, dim=-1).values:", result.values)
print()

# Test 2: nan_to_num with nan=0.0
print("=== Test 2: nan_to_num with nan=0.0 ===")
nan_to_zero = torch.nan_to_num(all_nan, nan=0.0)
print("After nan_to_num(nan=0.0):", nan_to_zero)
print("torch.max result:", torch.max(nan_to_zero, dim=-1).values)
print()

# Test 3: nan_to_num with nan=-inf
print("=== Test 3: nan_to_num with nan=-inf ===")
nan_to_neginf = torch.nan_to_num(all_nan, nan=-float('inf'))
print("After nan_to_num(nan=-inf):", nan_to_neginf)
max_result = torch.max(nan_to_neginf, dim=-1).values
print("torch.max result:", max_result)
print()

# Test 4: Restore NaN after max
print("=== Test 4: Restore NaN ===")
max_vals = max_result.clone()
print("Original max_vals:", max_vals)
max_vals[max_vals == -float('inf')] = float('nan')
print("After restoring NaN:", max_vals)
print()

# Test 5: Check if torch.nanmax exists and works
print("=== Test 5: torch.nanmax ===")
try:
    result = torch.nanmax(all_nan, dim=-1)
    print("torch.nanmax(all_nan, dim=-1):", result)
except Exception as e:
    print("Error:", e)

EOF



🏁 Script executed:

rg "all.*nan|empty.*bin|nan.*maximum" tests/metrics/test_calibration_metric.py -i



🏁 Script executed:

sed -n '1,150p' monai/metrics/calibration.py | tail -80



Use sentinel value to preserve NaN when all bins are empty in MAXIMUM reduction

When all bins are empty (all NaN), nan_to_num(..., nan=0.0) converts NaN values to 0, causing MCE to return 0 instead of NaN. This misrepresents "no valid data" as "zero error". Replace with -inf as sentinel and restore NaN for all-NaN cases:

Suggested fix
        elif self.calibration_reduction == CalibrationReduction.MAXIMUM:
-            abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=0.0)
-            return torch.max(abs_diff_no_nan, dim=-1).values  # Maximum across all dimensions
+            abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=-torch.inf)
+            max_vals = torch.max(abs_diff_no_nan, dim=-1).values
+            max_vals[max_vals == -torch.inf] = torch.nan
+            return max_vals  # Maximum across valid bins

Additionally, add a test case for the all-empty-bins edge case to prevent regression.
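
A possible shape for such a test, mirroring the sentinel logic of the suggested fix (the test name and assertion style are illustrative only):

```python
import torch

def test_maximum_reduction_all_empty_bins():
    # every bin empty -> per-bin absolute differences are all NaN
    abs_diff = torch.full((2, 3), float("nan"))
    masked = torch.nan_to_num(abs_diff, nan=-torch.inf)
    max_vals = torch.max(masked, dim=-1).values
    max_vals[max_vals == -torch.inf] = torch.nan   # restore NaN where no bin had data
    assert torch.isnan(max_vals).all()             # MCE reports "no data" instead of zero error
```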

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (resulting code):

        if self.calibration_reduction == CalibrationReduction.EXPECTED:
            # Calculate the weighted sum of absolute differences
            return torch.nansum(abs_diff * bin_counts, dim=-1) / torch.sum(bin_counts, dim=-1)
        elif self.calibration_reduction == CalibrationReduction.AVERAGE:
            return torch.nanmean(abs_diff, dim=-1)  # Average across all dimensions, ignoring nan
        elif self.calibration_reduction == CalibrationReduction.MAXIMUM:
            abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=-torch.inf)
            max_vals = torch.max(abs_diff_no_nan, dim=-1).values
            max_vals[max_vals == -torch.inf] = torch.nan
            return max_vals  # Maximum across valid bins
