@@ -145,7 +145,7 @@ def _forward_onepass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
145145 This method uses uniform noise to roughly model quantization.
146146 """
147147 y_hat = self .quantize (y )
148- y_ctx = self ._keep_only (self .context_prediction (y_hat ), "non_anchor" )
148+ y_ctx = self ._mask_all_but_step (self .context_prediction (y_hat ), "non_anchor" )
149149 params = self .entropy_parameters (self .merge (y_ctx , side_params ))
150150 y_out = self .latent_codec ["y" ](y , params )
151151 return {
@@ -167,39 +167,28 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
167167 To compute ``y_hat_anchors``, we need the predicted ``means_hat``:
168168 ``y_hat = quantize_ste(y - means_hat) + means_hat``.
169169 Thus, two passes of ``entropy_parameters`` are necessary.
170-
171170 """
172171 B , C , H , W = y .shape
173-
174172 params = y .new_zeros ((B , C * 2 , H , W ))
175-
176173 y_hat_ = []
177174
178- # NOTE: The _i variables contain only the current step's pixels.
179- # i=0: step=anchor
180- # i=1: step=non_anchor
181-
182175 for step in ("anchor" , "non_anchor" ):
176+ # Determine y_ctx for current step.
183177 if step == "anchor" :
184- y_ctx = self ._y_ctx_zero (y )
178+ y_ctx_i = self ._y_ctx_zero (y )
185179 else : # step == "non_anchor"
186- y_hat_anchors = y_hat_ [0 ]
187- y_ctx = self .context_prediction (y_hat_anchors )
180+ y_ctx_i = self .context_prediction (y_hat_ [0 ])
188181
189- params_i = self . entropy_parameters ( self . merge ( y_ctx , side_params ))
190-
191- # Save params for current step. This is later used for entropy estimation.
182+ # Determine params for current step.
183+ params_i = self . entropy_parameters ( self . merge ( y_ctx_i , side_params ))
184+ params_i = self . _mask_all_but_step ( params_i , step )
192185 self ._copy (params , params_i , step )
193186
194- # Keep only elements needed for current step.
195- # It's not necessary to mask the rest out just yet, but it doesn't hurt.
196- params_i = self ._keep_only (params_i , step )
197- y_i = self ._keep_only (y , step )
198-
199- # Determine y_hat for current step, and mask out the other pixels.
187+ # Determine y_hat for current step.
200188 _ , means_i = self .latent_codec ["y" ]._chunk (params_i )
201- y_hat_i = self ._keep_only (quantize_ste (y_i - means_i ) + means_i , step )
202-
189+ y_i = self ._mask_all_but_step (y , step )
190+ y_hat_i = quantize_ste (y_i - means_i ) + means_i
191+ y_hat_i = self ._mask_all_but_step (y_hat_i , step )
203192 y_hat_ .append (y_hat_i )
204193
205194 [y_hat_anchors , y_hat_non_anchors ] = y_hat_
@@ -224,13 +213,13 @@ def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, A
224213 """
225214 y_ctx = self ._y_ctx_zero (y )
226215 params = self .entropy_parameters (self .merge (y_ctx , side_params ))
227- params = self ._keep_only (params , "anchor" ) # Probably unnecessary.
216+ params = self ._mask_all_but_step (params , "anchor" ) # Probably unnecessary.
228217 _ , means_hat = self .latent_codec ["y" ]._chunk (params )
229218 y_hat_anchors = quantize_ste (y - means_hat ) + means_hat
230- y_hat_anchors = self ._keep_only (y_hat_anchors , "anchor" )
219+ y_hat_anchors = self ._mask_all_but_step (y_hat_anchors , "anchor" )
231220
232221 y_ctx = self .context_prediction (y_hat_anchors )
233- y_ctx = self ._keep_only (y_ctx , "non_anchor" ) # Probably unnecessary.
222+ y_ctx = self ._mask_all_but_step (y_ctx , "non_anchor" ) # Probably unnecessary.
234223 params = self .entropy_parameters (self .merge (y_ctx , side_params ))
235224 y_out = self .latent_codec ["y" ](y , params )
236225
@@ -365,7 +354,7 @@ def _copy(self, dest: Tensor, src: Tensor, step: str) -> None:
365354 dest [..., 0 ::2 , 1 ::2 ] = src [..., 0 ::2 , 1 ::2 ]
366355 dest [..., 1 ::2 , 0 ::2 ] = src [..., 1 ::2 , 0 ::2 ]
367356
368- def _keep_only (self , y : Tensor , step : str ) -> Tensor :
357+ def _mask_all_but_step (self , y : Tensor , step : str ) -> Tensor :
369358 """Keep only pixels in the current step, and zero out the rest."""
370359 y = y .clone ()
371360 parity = self .anchor_parity if step == "anchor" else self .non_anchor_parity
@@ -382,7 +371,7 @@ def _mask_all(self, y: Tensor) -> Tensor:
382371 y [:] = 0
383372 return y
384373
385- def merge (self , * args ) :
374+ def merge (self , * args : Tensor ) -> Tensor :
386375 return torch .cat (args , dim = 1 )
387376
388377 def quantize (self , y : Tensor ) -> Tensor :