Skip to content

Commit d46d0e5

Browse files
committed
Clean up StateDict removal and fix standalone batch tests
Remove the residual dict-to-SimState fallback in both forward() methods now that StateDict support is dropped. Fix batch tests in the standalone package to use same-sized supercells (different lattice params) instead of mixed unit cell + supercell, matching the inline test fix.
1 parent 4dd5f07 commit d46d0e5

3 files changed

Lines changed: 24 additions & 31 deletions

File tree

python/metatomic_torch/metatomic/torch/torchsim.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -185,18 +185,9 @@ def forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: # noqa: C901,
185185
- 'stress': System stresses with shape [n_systems, 3, 3] if
186186
compute_stress=True
187187
"""
188-
sim_state = (
189-
state
190-
if isinstance(state, ts.SimState)
191-
else ts.SimState(**state, masses=torch.ones_like(state["positions"]))
192-
)
193-
194-
# Input validation is already done inside the forward method of the
195-
# AtomisticModel class, so we don't need to do it again here.
196-
197-
atomic_nums = sim_state.atomic_numbers
198-
cell = sim_state.row_vector_cell
199-
positions = sim_state.positions
188+
atomic_nums = state.atomic_numbers
189+
cell = state.row_vector_cell
190+
positions = state.positions
200191

201192
# Check dtype (metatomic models require a specific input dtype)
202193
if positions.dtype != self._dtype:
@@ -216,7 +207,7 @@ def forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: # noqa: C901,
216207
systems: list[System] = []
217208
strains = []
218209
for sys_idx in range(len(cell)):
219-
system_mask = sim_state.system_idx == sys_idx
210+
system_mask = state.system_idx == sys_idx
220211
system_positions = positions[system_mask]
221212
system_cell = cell[sys_idx]
222213
system_atomic_numbers = atomic_nums[system_mask]
@@ -237,7 +228,7 @@ def forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: # noqa: C901,
237228
positions=system_positions,
238229
types=system_atomic_numbers,
239230
cell=system_cell,
240-
pbc=sim_state.pbc,
231+
pbc=state.pbc,
241232
)
242233
)
243234

python/metatomic_torchsim/metatomic/torchsim/_calculator.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,9 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]:
184184
``"forces"`` (shape ``[n_atoms, 3]``, if ``compute_forces``), and
185185
``"stress"`` (shape ``[n_systems, 3, 3]``, if ``compute_stress``).
186186
"""
187-
sim_state = (
188-
state
189-
if isinstance(state, ts.SimState)
190-
else ts.SimState(**state, masses=torch.ones_like(state["positions"]))
191-
)
192-
193-
positions = sim_state.positions
194-
cell = sim_state.row_vector_cell
195-
atomic_nums = sim_state.atomic_numbers
187+
positions = state.positions
188+
cell = state.row_vector_cell
189+
atomic_nums = state.atomic_numbers
196190

197191
if positions.dtype != self._dtype:
198192
raise TypeError(
@@ -207,7 +201,7 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]:
207201
n_systems = len(cell)
208202

209203
for sys_idx in range(n_systems):
210-
mask = sim_state.system_idx == sys_idx
204+
mask = state.system_idx == sys_idx
211205
sys_positions = positions[mask]
212206
sys_cell = cell[sys_idx]
213207
sys_types = atomic_nums[mask]
@@ -231,7 +225,7 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]:
231225
positions=sys_positions,
232226
types=sys_types,
233227
cell=sys_cell,
234-
pbc=sim_state.pbc,
228+
pbc=state.pbc,
235229
)
236230
)
237231

python/metatomic_torchsim/tests/test_torchsim.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,21 @@ def test_forward_no_forces(lj_model, ni_atoms):
132132
assert "stress" in output
133133

134134

135-
def test_batched_forward(metatomic_model, ni_atoms):
136-
"""Forward pass handles batched systems correctly."""
135+
def _make_ni_atoms_2():
136+
"""Create a second Ni supercell (same size, different lattice parameter)."""
137137
import ase.build
138138

139-
atoms_2 = ase.build.bulk("Ni", "fcc", a=3.6, cubic=True)
139+
np.random.seed(0xCAFEBABE)
140+
atoms = ase.build.make_supercell(
141+
ase.build.bulk("Ni", "fcc", a=3.5, cubic=True), 2 * np.eye(3)
142+
)
143+
atoms.positions += 0.1 * np.random.rand(*atoms.positions.shape)
144+
return atoms
145+
146+
147+
def test_batched_forward(metatomic_model, ni_atoms):
148+
"""Forward pass handles batched systems correctly."""
149+
atoms_2 = _make_ni_atoms_2()
140150
sim_state = ts.io.atoms_to_state([ni_atoms, atoms_2], DEVICE, DTYPE)
141151
output = metatomic_model(sim_state)
142152

@@ -148,9 +158,7 @@ def test_batched_forward(metatomic_model, ni_atoms):
148158

149159
def test_energy_consistency_single_vs_batch(metatomic_model, ni_atoms):
150160
"""Energy from single system matches the corresponding entry in a batch."""
151-
import ase.build
152-
153-
atoms_2 = ase.build.bulk("Ni", "fcc", a=3.6, cubic=True)
161+
atoms_2 = _make_ni_atoms_2()
154162

155163
# single
156164
state_1 = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE)

0 commit comments

Comments
 (0)