Fix test_mutation_not_double_allocated to not use hardcoded indices #16667
```diff
@@ -665,22 +665,51 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         et = to_edge(export(model, inputs, strict=True)).to_executorch()

-        # 0 and 11 should refer to the same tensor. 0 is the input, 11 is the output of copy_
-        self.assertEqual(
-            et.executorch_program.execution_plan[0]
-            .values[0]
-            .val.allocation_info.memory_offset_low,
-            et.executorch_program.execution_plan[0]
-            .values[11]
-            .val.allocation_info.memory_offset_low,
-        )
+        # The mutable buffer (5x5 float32 = 100 bytes) should not be double allocated.
+        # The input and output of copy_ should share the same memory location.
+        values = et.executorch_program.execution_plan[0].values
+        expected_buffer_size = 5 * 5 * 4  # 5x5 float32
+
+        # Collect all tensor allocations by their (memory_id, offset) and track sizes.
+        # Size is computed from the tensor's sizes and scalar_type, not from allocation_info
+        # (memory_offset_low/high are the low/high 32-bit parts of a 64-bit offset, not bounds).
+        scalar_type_sizes = {
+            0: 1,  # BYTE
+            1: 1,  # CHAR
+            2: 2,  # SHORT
+            3: 4,  # INT
+            4: 8,  # LONG
+            5: 2,  # HALF
+            6: 4,  # FLOAT
+            7: 8,  # DOUBLE
+        }
+        offset_to_indices = {}
+        for i, val in enumerate(values):
+            tensor = val.val
+            if hasattr(tensor, "allocation_info") and tensor.allocation_info:
+                alloc = tensor.allocation_info
+                # Compute tensor size from sizes and scalar_type
+                num_elements = 1
+                for dim in tensor.sizes:
+                    num_elements *= dim
+                element_size = scalar_type_sizes.get(int(tensor.scalar_type), 4)
+                size = num_elements * element_size
+                key = (alloc.memory_id, alloc.memory_offset)
+                if key not in offset_to_indices:
+                    offset_to_indices[key] = {"indices": [], "size": size}
```
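The core of the new check — grouping allocations by `(memory_id, memory_offset)` and verifying that every tensor sharing a location reports the same byte size — can be sketched in isolation. The `Alloc` and `Tensor` dataclasses below are hypothetical stand-ins for the emitted program's records, not the real ExecuTorch schema:

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the emitted program's allocation/tensor records.
@dataclass
class Alloc:
    memory_id: int
    memory_offset: int

@dataclass
class Tensor:
    sizes: list
    element_size: int  # bytes per element; derived from scalar_type in the real test
    allocation_info: Alloc

def group_allocations(tensors):
    """Group tensor indices by (memory_id, offset) and track each group's byte size."""
    groups = {}
    for i, t in enumerate(tensors):
        num_elements = 1
        for dim in t.sizes:
            num_elements *= dim
        size = num_elements * t.element_size
        key = (t.allocation_info.memory_id, t.allocation_info.memory_offset)
        entry = groups.setdefault(key, {"indices": [], "size": size})
        # Tensors sharing a location must agree on size, or allocation is suspect.
        assert entry["size"] == size, f"inconsistent sizes at {key}"
        entry["indices"].append(i)
    return groups

# Two 5x5 float32 tensors sharing one location (the copy_ input/output pair),
# plus an unrelated tensor at a different offset.
tensors = [
    Tensor([5, 5], 4, Alloc(1, 0)),
    Tensor([5, 5], 4, Alloc(1, 0)),
    Tensor([10], 4, Alloc(1, 128)),
]
groups = group_allocations(tensors)
assert groups[(1, 0)]["indices"] == [0, 1]  # shared location, not double allocated
assert groups[(1, 0)]["size"] == 5 * 5 * 4  # 100 bytes
```

The point of keying on the full `(memory_id, memory_offset)` pair rather than hardcoded value indices 0 and 11 is that memory-planning passes may reorder the value table without invalidating the test.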
A review comment on the final added line proposed also verifying size consistency when a memory location is reused, via a suggested change:

```python
                    offset_to_indices[key] = {"indices": [], "size": size}
                else:
                    existing_size = offset_to_indices[key]["size"]
                    self.assertEqual(
                        existing_size,
                        size,
                        f"Inconsistent tensor sizes for shared memory location "
                        f"{key}: previously {existing_size}, now {size}",
                    )
```
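The diff's comment that `memory_offset_low`/`memory_offset_high` are the two halves of one 64-bit offset (rather than low/high bounds) implies a reconstruction like the following sketch; the helper name is mine, and the field semantics are taken from that comment:

```python
def combined_offset(memory_offset_low: int, memory_offset_high: int) -> int:
    # The two fields are unsigned 32-bit halves of a single 64-bit byte offset.
    # Comparing only the low half (as the old test did) could treat two offsets
    # that differ by a multiple of 2**32 as equal.
    return (memory_offset_high << 32) | (memory_offset_low & 0xFFFFFFFF)

assert combined_offset(100, 0) == 100
assert combined_offset(0, 1) == 2**32
assert combined_offset(0xFFFFFFFF, 1) == 2**33 - 1
```

This is why the rewritten test keys on `allocation_info.memory_offset` directly instead of comparing `memory_offset_low` values.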
The scalar_type_sizes dictionary is incomplete and missing many scalar types defined in ScalarType enum (exir/scalar_type.py). Missing types include: BOOL (11), QINT8 (12), QUINT8 (13), QINT32 (14), BFLOAT16 (15), QUINT4x2 (16), QUINT2x4 (17), BITS16 (22), FLOAT8E5M2 (23), FLOAT8E4M3FN (24), FLOAT8E5M2FNUZ (25), FLOAT8E4M3FNUZ (26), UINT16 (27), UINT32 (28), UINT64 (29), and COMPLEX32 (8), COMPLEX64 (9), COMPLEX128 (10).
While the test currently only expects float32 tensors, if other tensor types appear in the graph, the code will fall back to the default size of 4 bytes, which could produce incorrect size calculations and cause the test to incorrectly pass or fail. Consider either completing the mapping or failing loudly on scalar types that are not in it.
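One way to address the reviewer's concern — a sketch, not the repository's code — is to make unknown scalar types raise instead of silently defaulting to 4 bytes. The entries beyond DOUBLE below are assumed element widths and should be checked against `exir/scalar_type.py` before use:

```python
# Element sizes in bytes, keyed by ScalarType enum value.
# Entries past DOUBLE are assumptions; verify against exir/scalar_type.py.
SCALAR_TYPE_SIZES = {
    0: 1,   # BYTE
    1: 1,   # CHAR
    2: 2,   # SHORT
    3: 4,   # INT
    4: 8,   # LONG
    5: 2,   # HALF
    6: 4,   # FLOAT
    7: 8,   # DOUBLE
    11: 1,  # BOOL (assumed 1 byte)
    15: 2,  # BFLOAT16 (assumed 2 bytes)
    27: 2,  # UINT16 (assumed)
    28: 4,  # UINT32 (assumed)
    29: 8,  # UINT64 (assumed)
}

def element_size(scalar_type: int) -> int:
    """Look up the element width, raising instead of guessing on unknown types."""
    try:
        return SCALAR_TYPE_SIZES[int(scalar_type)]
    except KeyError:
        raise ValueError(
            f"unhandled scalar_type {scalar_type}; add it to SCALAR_TYPE_SIZES"
        )

assert element_size(6) == 4  # FLOAT
```

Raising on an unmapped type turns a silent miscalculation into an immediate, diagnosable test failure when a new dtype shows up in the graph.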