@@ -1179,9 +1179,9 @@ def _npt_nose_hoover_compute_cell_force(
11791179 internal_pressure = torch .trace (stress ).unsqueeze (0 ).expand (n_systems )
11801180
11811181 # Compute force on cell coordinate per system
1182- # F = alpha * KE - dU/dV - P*V*d
1182+ # F = alpha * (2 * KE) - dU/dV - P*V*d
11831183 return (
1184- (alpha * KE_per_system )
1184+ (alpha * 2 * KE_per_system )
11851185 - (internal_pressure * volume )
11861186 - (external_pressure * volume * dim )
11871187 )
@@ -1226,21 +1226,18 @@ def _npt_nose_hoover_inner_step(
12261226 volume , volume_to_cell = _npt_nose_hoover_cell_info (state )
12271227 cell = volume_to_cell (volume )
12281228
1229- # Get model output
1230- state .cell = cell
1231- model_output = model (state )
1232-
12331229 # First half step: Update momenta
1234- n_atoms_per_system = torch . bincount ( state . system_idx , minlength = state . n_systems )
1235- alpha = 1 + 1 / n_atoms_per_system # [n_systems]
1230+ # alpha = 1 + dim / degrees_of_freedom (3 * natoms - 3 )
1231+ alpha = 1 + 3 / state . get_number_of_degrees_of_freedom () # [n_systems]
12361232
1233+ # Reuse stress from previous step since positions and cell unchanged
12371234 cell_force_val = _npt_nose_hoover_compute_cell_force (
12381235 alpha = alpha ,
12391236 volume = volume ,
12401237 positions = positions ,
12411238 momenta = momenta ,
12421239 masses = masses ,
1243- stress = model_output [ " stress" ] ,
1240+ stress = state . stress ,
12441241 external_pressure = external_pressure ,
12451242 system_idx = state .system_idx ,
12461243 )
@@ -1406,7 +1403,8 @@ def npt_nose_hoover_init(
14061403 )
14071404
14081405 # Compute total DOF for thermostat initialization and a zero KE placeholder
1409- dof_per_system = torch .bincount (state .system_idx , minlength = n_systems ) * dim
1406+ dof_per_system = state .get_number_of_degrees_of_freedom () - 3
1407+
14101408 KE_thermostat = ts .calc_kinetic_energy (
14111409 masses = state .masses , momenta = momenta , system_idx = state .system_idx
14121410 )
@@ -1612,13 +1610,12 @@ def npt_nose_hoover_invariant(
16121610 )
16131611
16141612 # Calculate degrees of freedom per system
1615- n_atoms_per_system = torch .bincount (state .system_idx , minlength = state .n_systems )
1616- dof_per_system = n_atoms_per_system * state .positions .shape [- 1 ] # n_atoms * n_dim
1613+ dof_per_system = state .get_number_of_degrees_of_freedom ()
16171614
16181615 # Initialize total energy with PE + KE
16191616 e_tot = e_pot + e_kin_per_system
16201617
1621- # Add thermostat chain contributions (batched per system, DOF = n_atoms * 3)
1618+ # Add thermostat chain contributions (batched per system, DOF = 3 * n_atoms - 3)
16221619 e_tot += _compute_chain_energy (state .thermostat , kT , e_tot , dof_per_system )
16231620
16241621 # Add barostat chain contributions (batched per system, DOF = 1)
0 commit comments