diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index ce20de8f820..dc25d1ad7c6 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -665,22 +665,51 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: et = to_edge(export(model, inputs, strict=True)).to_executorch() - # 0 and 11 should refer to the same tensor. 0 is the input, 11 is the output of copy_ - self.assertEqual( - et.executorch_program.execution_plan[0] - .values[0] - .val.allocation_info.memory_offset_low, - et.executorch_program.execution_plan[0] - .values[11] - .val.allocation_info.memory_offset_low, - ) + # The mutable buffer (5x5 float32 = 100 bytes) should not be double allocated. + # The input and output of copy_ should share the same memory location. + values = et.executorch_program.execution_plan[0].values + expected_buffer_size = 5 * 5 * 4 # 5x5 float32 + + # Collect all tensor allocations by their (memory_id, offset) and track sizes + # Size is computed from tensor's sizes and scalar_type, not from allocation_info + # (memory_offset_low/high are low/high 32-bit parts of a 64-bit offset, not bounds) + scalar_type_sizes = { + 0: 1, # BYTE + 1: 1, # CHAR + 2: 2, # SHORT + 3: 4, # INT + 4: 8, # LONG + 5: 2, # HALF + 6: 4, # FLOAT + 7: 8, # DOUBLE + } + offset_to_indices = {} + for i, val in enumerate(values): + tensor = val.val + if hasattr(tensor, "allocation_info") and tensor.allocation_info: + alloc = tensor.allocation_info + # Compute tensor size from sizes and scalar_type + num_elements = 1 + for dim in tensor.sizes: + num_elements *= dim + element_size = scalar_type_sizes.get(int(tensor.scalar_type), 4) + size = num_elements * element_size + key = (alloc.memory_id, alloc.memory_offset) + if key not in offset_to_indices: + offset_to_indices[key] = {"indices": [], "size": size} + offset_to_indices[key]["indices"].append(i) + + # Find shared allocations matching the mutable buffer size (before/after copy_) + mutable_buffer_shares = [ + info + for info in offset_to_indices.values() + if len(info["indices"]) == 2 and info["size"] == expected_buffer_size + ] self.assertEqual( - et.executorch_program.execution_plan[0] - .values[0] - .val.allocation_info.memory_offset_high, - et.executorch_program.execution_plan[0] - .values[11] - .val.allocation_info.memory_offset_high, + len(mutable_buffer_shares), + 1, + f"Expected exactly one shared allocation of size {expected_buffer_size} " + f"with 2 values (copy_ input/output), found: {mutable_buffer_shares}", ) def test_mutable_buffers_infinite_lifespan(self) -> None: