Skip to content

Commit 5e04fb8

Browse files
YodaEmbedding authored and fracape committed
refactor: simplify CheckerboardLatentCodec
1 parent 6773ad0 commit 5e04fb8

File tree

1 file changed

+16
-27
lines changed

1 file changed

+16
-27
lines changed

compressai/latent_codecs/checkerboard.py

Lines changed: 16 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -145,7 +145,7 @@ def _forward_onepass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
145145
This method uses uniform noise to roughly model quantization.
146146
"""
147147
y_hat = self.quantize(y)
148-
y_ctx = self._keep_only(self.context_prediction(y_hat), "non_anchor")
148+
y_ctx = self._mask_all_but_step(self.context_prediction(y_hat), "non_anchor")
149149
params = self.entropy_parameters(self.merge(y_ctx, side_params))
150150
y_out = self.latent_codec["y"](y, params)
151151
return {
@@ -167,39 +167,28 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
167167
To compute ``y_hat_anchors``, we need the predicted ``means_hat``:
168168
``y_hat = quantize_ste(y - means_hat) + means_hat``.
169169
Thus, two passes of ``entropy_parameters`` are necessary.
170-
171170
"""
172171
B, C, H, W = y.shape
173-
174172
params = y.new_zeros((B, C * 2, H, W))
175-
176173
y_hat_ = []
177174

178-
# NOTE: The _i variables contain only the current step's pixels.
179-
# i=0: step=anchor
180-
# i=1: step=non_anchor
181-
182175
for step in ("anchor", "non_anchor"):
176+
# Determine y_ctx for current step.
183177
if step == "anchor":
184-
y_ctx = self._y_ctx_zero(y)
178+
y_ctx_i = self._y_ctx_zero(y)
185179
else: # step == "non_anchor"
186-
y_hat_anchors = y_hat_[0]
187-
y_ctx = self.context_prediction(y_hat_anchors)
180+
y_ctx_i = self.context_prediction(y_hat_[0])
188181

189-
params_i = self.entropy_parameters(self.merge(y_ctx, side_params))
190-
191-
# Save params for current step. This is later used for entropy estimation.
182+
# Determine params for current step.
183+
params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params))
184+
params_i = self._mask_all_but_step(params_i, step)
192185
self._copy(params, params_i, step)
193186

194-
# Keep only elements needed for current step.
195-
# It's not necessary to mask the rest out just yet, but it doesn't hurt.
196-
params_i = self._keep_only(params_i, step)
197-
y_i = self._keep_only(y, step)
198-
199-
# Determine y_hat for current step, and mask out the other pixels.
187+
# Determine y_hat for current step.
200188
_, means_i = self.latent_codec["y"]._chunk(params_i)
201-
y_hat_i = self._keep_only(quantize_ste(y_i - means_i) + means_i, step)
202-
189+
y_i = self._mask_all_but_step(y, step)
190+
y_hat_i = quantize_ste(y_i - means_i) + means_i
191+
y_hat_i = self._mask_all_but_step(y_hat_i, step)
203192
y_hat_.append(y_hat_i)
204193

205194
[y_hat_anchors, y_hat_non_anchors] = y_hat_
@@ -224,13 +213,13 @@ def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, A
224213
"""
225214
y_ctx = self._y_ctx_zero(y)
226215
params = self.entropy_parameters(self.merge(y_ctx, side_params))
227-
params = self._keep_only(params, "anchor") # Probably unnecessary.
216+
params = self._mask_all_but_step(params, "anchor") # Probably unnecessary.
228217
_, means_hat = self.latent_codec["y"]._chunk(params)
229218
y_hat_anchors = quantize_ste(y - means_hat) + means_hat
230-
y_hat_anchors = self._keep_only(y_hat_anchors, "anchor")
219+
y_hat_anchors = self._mask_all_but_step(y_hat_anchors, "anchor")
231220

232221
y_ctx = self.context_prediction(y_hat_anchors)
233-
y_ctx = self._keep_only(y_ctx, "non_anchor") # Probably unnecessary.
222+
y_ctx = self._mask_all_but_step(y_ctx, "non_anchor") # Probably unnecessary.
234223
params = self.entropy_parameters(self.merge(y_ctx, side_params))
235224
y_out = self.latent_codec["y"](y, params)
236225

@@ -365,7 +354,7 @@ def _copy(self, dest: Tensor, src: Tensor, step: str) -> None:
365354
dest[..., 0::2, 1::2] = src[..., 0::2, 1::2]
366355
dest[..., 1::2, 0::2] = src[..., 1::2, 0::2]
367356

368-
def _keep_only(self, y: Tensor, step: str) -> Tensor:
357+
def _mask_all_but_step(self, y: Tensor, step: str) -> Tensor:
369358
"""Keep only pixels in the current step, and zero out the rest."""
370359
y = y.clone()
371360
parity = self.anchor_parity if step == "anchor" else self.non_anchor_parity
@@ -382,7 +371,7 @@ def _mask_all(self, y: Tensor) -> Tensor:
382371
y[:] = 0
383372
return y
384373

385-
def merge(self, *args):
374+
def merge(self, *args: Tensor) -> Tensor:
386375
return torch.cat(args, dim=1)
387376

388377
def quantize(self, y: Tensor) -> Tensor:

0 commit comments

Comments (0)