8888 [torch .as_tensor ([[0.0 , 1.0 , 0.0 ], [0.6667 , 0.0 , 0.4 ]]), torch .as_tensor ([[0.0 , 0.5 , 0.0 ], [0.3333 , 0.0 , 0.4 ]])],
8989]
9090
# 3D test cases: one-sample batches shaped (1, 2, 2, 2, 2) with an
# instance-ID channel followed by a class-ID channel.
_pred_3d_instance = [[[2, 0], [1, 1]], [[0, 1], [2, 1]]]
_pred_3d_class = [[[0, 1], [3, 0]], [[1, 0], [1, 1]]]
sample_3d_pred = torch.as_tensor([[_pred_3d_instance, _pred_3d_class]], device=_device)

_gt_3d_instance = [[[2, 0], [0, 0]], [[2, 2], [2, 3]]]
_gt_3d_class = [[[3, 3], [3, 2]], [[2, 2], [3, 3]]]
sample_3d_gt = torch.as_tensor([[_gt_3d_instance, _gt_3d_class]], device=_device)

# test 3D sample, num_classes = 3, match_iou_threshold = 0.5
TEST_3D_CASE_1 = [{"num_classes": 3, "match_iou_threshold": 0.5}, sample_3d_pred, sample_3d_gt]

# test confusion matrix return
TEST_CM_CASE_1 = [
    {"num_classes": 3, "match_iou_threshold": 0.5, "return_confusion_matrix": True},
    sample_3_pred,
    sample_3_gt,
]
121+
91122
92123@SkipIfNoModule ("scipy.optimize" )
93124class TestPanopticQualityMetric (unittest .TestCase ):
@@ -108,6 +139,98 @@ def test_value_class(self, input_params, y_pred, y_gt, expected_value):
108139 else :
109140 np .testing .assert_allclose (outputs .cpu ().numpy (), np .asarray (expected_value ), atol = 1e-4 )
110141
142+ def test_3d_support (self ):
143+ """Test that 3D input is properly supported."""
144+ input_params , y_pred , y_gt = TEST_3D_CASE_1
145+ metric = PanopticQualityMetric (** input_params )
146+ # Should not raise an error for 3D input
147+ metric (y_pred , y_gt )
148+ outputs = metric .aggregate ()
149+ # Check that output is a tensor
150+ self .assertIsInstance (outputs , torch .Tensor )
151+ # Check that output shape is correct (num_classes,)
152+ self .assertEqual (outputs .shape , torch .Size ([3 ]))
153+
154+ def test_confusion_matrix_return (self ):
155+ """Test that confusion matrix can be returned instead of computed metrics."""
156+ input_params , y_pred , y_gt = TEST_CM_CASE_1
157+ metric = PanopticQualityMetric (** input_params )
158+ metric (y_pred , y_gt )
159+ outputs = metric .aggregate ()
160+ # Check that output is a tensor with shape (batch_size, num_classes, 4)
161+ self .assertIsInstance (outputs , torch .Tensor )
162+ self .assertEqual (outputs .shape [- 1 ], 4 )
163+ # Verify that values correspond to [tp, fp, fn, iou_sum]
164+ tp , fp , fn , iou_sum = outputs [..., 0 ], outputs [..., 1 ], outputs [..., 2 ], outputs [..., 3 ]
165+ # tp, fp, fn should be non-negative integers
166+ self .assertTrue (torch .all (tp >= 0 ))
167+ self .assertTrue (torch .all (fp >= 0 ))
168+ self .assertTrue (torch .all (fn >= 0 ))
169+ # iou_sum should be non-negative float
170+ self .assertTrue (torch .all (iou_sum >= 0 ))
171+
172+ def test_compute_mean_iou (self ):
173+ """Test mean IoU computation from confusion matrix."""
174+ from monai .metrics .panoptic_quality import compute_mean_iou
175+
176+ input_params , y_pred , y_gt = TEST_CM_CASE_1
177+ metric = PanopticQualityMetric (** input_params )
178+ metric (y_pred , y_gt )
179+ confusion_matrix = metric .aggregate ()
180+ mean_iou = compute_mean_iou (confusion_matrix )
181+ # Check shape is correct
182+ self .assertEqual (mean_iou .shape , confusion_matrix .shape [:- 1 ])
183+ # Check values are non-negative
184+ self .assertTrue (torch .all (mean_iou >= 0 ))
185+
186+ def test_metric_name_filtering (self ):
187+ """Test that metric_name parameter properly filters output."""
188+ # Test single metric "sq"
189+ metric_sq = PanopticQualityMetric (num_classes = 3 , metric_name = "sq" , match_iou_threshold = 0.5 )
190+ metric_sq (sample_3_pred , sample_3_gt )
191+ result_sq = metric_sq .aggregate ()
192+ self .assertIsInstance (result_sq , torch .Tensor )
193+ self .assertEqual (result_sq .shape , torch .Size ([3 ]))
194+
195+ # Test single metric "rq"
196+ metric_rq = PanopticQualityMetric (num_classes = 3 , metric_name = "rq" , match_iou_threshold = 0.5 )
197+ metric_rq (sample_3_pred , sample_3_gt )
198+ result_rq = metric_rq .aggregate ()
199+ self .assertIsInstance (result_rq , torch .Tensor )
200+ self .assertEqual (result_rq .shape , torch .Size ([3 ]))
201+
202+ # Results should be different for different metrics
203+ self .assertFalse (torch .allclose (result_sq , result_rq , atol = 1e-4 ))
204+
205+ def test_invalid_3d_shape (self ):
206+ """Test that invalid 3D shapes are rejected."""
207+ # Shape with 3 dimensions should fail
208+ invalid_pred = torch .randint (0 , 5 , (2 , 2 , 10 ))
209+ invalid_gt = torch .randint (0 , 5 , (2 , 2 , 10 ))
210+ metric = PanopticQualityMetric (num_classes = 3 )
211+ with self .assertRaises (ValueError ):
212+ metric (invalid_pred , invalid_gt )
213+
214+ # Shape with 6 dimensions should fail
215+ invalid_pred = torch .randint (0 , 5 , (1 , 2 , 8 , 8 , 8 , 8 ))
216+ invalid_gt = torch .randint (0 , 5 , (1 , 2 , 8 , 8 , 8 , 8 ))
217+ with self .assertRaises (ValueError ):
218+ metric (invalid_pred , invalid_gt )
219+
220+ def test_compute_mean_iou_invalid_shape (self ):
221+ """Test that compute_mean_iou raises ValueError for invalid shapes."""
222+ from monai .metrics .panoptic_quality import compute_mean_iou
223+
224+ # Shape (..., 3) instead of (..., 4) should fail
225+ invalid_confusion_matrix = torch .zeros (3 , 3 )
226+ with self .assertRaises (ValueError ):
227+ compute_mean_iou (invalid_confusion_matrix )
228+
229+ # Shape (..., 5) should also fail
230+ invalid_confusion_matrix = torch .zeros (2 , 5 )
231+ with self .assertRaises (ValueError ):
232+ compute_mean_iou (invalid_confusion_matrix )
233+
111234
# Allow this test module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()