Skip to content

Commit cbfbec5

Browse files
authored
Fix NPT Nose-Hoover things (#520)
1 parent c66e610 commit cbfbec5

4 files changed

Lines changed: 18 additions & 24 deletions

File tree

tests/test_integrators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def test_nvt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone
339339
temperatures_list = [t.tolist() for t in temperatures_tensor.T]
340340
assert torch.allclose(
341341
temperatures_tensor[-1],
342-
torch.tensor([300.0096, 299.7024], dtype=dtype),
342+
torch.tensor([290.3553, 289.9699], dtype=dtype),
343343
)
344344

345345
energies_tensor = torch.stack(energies)
@@ -728,7 +728,7 @@ def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone
728728
temperatures_list = [t.tolist() for t in temperatures_tensor.T]
729729
assert torch.allclose(
730730
temperatures_tensor[-1],
731-
torch.tensor([298.2752, 297.9444], dtype=dtype),
731+
torch.tensor([287.5729, 287.1330], dtype=dtype),
732732
)
733733

734734
energies_tensor = torch.stack(energies)

torch_sim/integrators/md.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def init_fn(
409409

410410
Q = (
411411
kT_batched.unsqueeze(-1)
412-
* torch.square(tau_batched).unsqueeze(-1) ** 2
412+
* torch.square(tau_batched).unsqueeze(-1)
413413
* torch.ones((n_systems, chain_length), dtype=dtype, device=device)
414414
)
415415
Q[:, 0] *= degrees_of_freedom

torch_sim/integrators/npt.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,9 +1179,9 @@ def _npt_nose_hoover_compute_cell_force(
11791179
internal_pressure = torch.trace(stress).unsqueeze(0).expand(n_systems)
11801180

11811181
# Compute force on cell coordinate per system
1182-
# F = alpha * KE - dU/dV - P*V*d
1182+
# F = alpha * (2 * KE) - dU/dV - P*V*d
11831183
return (
1184-
(alpha * KE_per_system)
1184+
(alpha * 2 * KE_per_system)
11851185
- (internal_pressure * volume)
11861186
- (external_pressure * volume * dim)
11871187
)
@@ -1226,21 +1226,18 @@ def _npt_nose_hoover_inner_step(
12261226
volume, volume_to_cell = _npt_nose_hoover_cell_info(state)
12271227
cell = volume_to_cell(volume)
12281228

1229-
# Get model output
1230-
state.cell = cell
1231-
model_output = model(state)
1232-
12331229
# First half step: Update momenta
1234-
n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems)
1235-
alpha = 1 + 1 / n_atoms_per_system # [n_systems]
1230+
# alpha = 1 + dim / degrees_of_freedom (3 * natoms - 3)
1231+
alpha = 1 + 3 / state.get_number_of_degrees_of_freedom() # [n_systems]
12361232

1233+
# Reuse stress from previous step since positions and cell unchanged
12371234
cell_force_val = _npt_nose_hoover_compute_cell_force(
12381235
alpha=alpha,
12391236
volume=volume,
12401237
positions=positions,
12411238
momenta=momenta,
12421239
masses=masses,
1243-
stress=model_output["stress"],
1240+
stress=state.stress,
12441241
external_pressure=external_pressure,
12451242
system_idx=state.system_idx,
12461243
)
@@ -1406,7 +1403,8 @@ def npt_nose_hoover_init(
14061403
)
14071404

14081405
# Compute total DOF for thermostat initialization and a zero KE placeholder
1409-
dof_per_system = torch.bincount(state.system_idx, minlength=n_systems) * dim
1406+
dof_per_system = state.get_number_of_degrees_of_freedom() - 3
1407+
14101408
KE_thermostat = ts.calc_kinetic_energy(
14111409
masses=state.masses, momenta=momenta, system_idx=state.system_idx
14121410
)
@@ -1612,13 +1610,12 @@ def npt_nose_hoover_invariant(
16121610
)
16131611

16141612
# Calculate degrees of freedom per system
1615-
n_atoms_per_system = torch.bincount(state.system_idx, minlength=state.n_systems)
1616-
dof_per_system = n_atoms_per_system * state.positions.shape[-1] # n_atoms * n_dim
1613+
dof_per_system = state.get_number_of_degrees_of_freedom()
16171614

16181615
# Initialize total energy with PE + KE
16191616
e_tot = e_pot + e_kin_per_system
16201617

1621-
# Add thermostat chain contributions (batched per system, DOF = n_atoms * 3)
1618+
# Add thermostat chain contributions (batched per system, DOF = 3 * n_atoms - 3)
16221619
e_tot += _compute_chain_energy(state.thermostat, kT, e_tot, dof_per_system)
16231620

16241621
# Add barostat chain contributions (batched per system, DOF = 1)

torch_sim/integrators/nvt.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,9 @@ def nvt_nose_hoover_init(
314314
masses=state.masses, momenta=momenta, system_idx=state.system_idx
315315
)
316316

317-
# Calculate degrees of freedom per system
318-
n_atoms_per_system = torch.bincount(state.system_idx)
319-
dof_per_system = (
320-
n_atoms_per_system * state.positions.shape[-1]
321-
) # n_atoms * n_dimensions
317+
# Calculate degrees of freedom per system (subtract 3 for COM motion,
318+
# matching LAMMPS compute_temp which uses dof = 3N - 3)
319+
dof_per_system = state.get_number_of_degrees_of_freedom() - 3
322320

323321
# Initialize state
324322
return NVTNoseHooverState.from_state(
@@ -431,9 +429,8 @@ def nvt_nose_hoover_invariant(
431429
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
432430
)
433431

434-
# Get system degrees of freedom per system
435-
n_atoms_per_system = torch.bincount(state.system_idx)
436-
dof = n_atoms_per_system * state.positions.shape[-1] # n_atoms * n_dimensions
432+
# Get system degrees of freedom per system (3N - 3 for COM correction)
433+
dof = state.get_number_of_degrees_of_freedom()
437434

438435
# Start with system energy
439436
e_tot = e_pot + e_kin

0 commit comments

Comments
 (0)