Commit 743d866

Merge remote-tracking branch 'origin/main' into upstream-fairchem
2 parents: 3feb934 + c456ece

48 files changed

Lines changed: 3010 additions & 173 deletions

.github/workflows/test.yml

Lines changed: 2 additions & 1 deletion
```diff
@@ -69,6 +69,7 @@ jobs:
           - { name: mattersim, test_path: "tests/models/test_mattersim.py" }
           - { name: metatomic, test_path: "tests/models/test_metatomic.py" }
           - { name: nequip, test_path: "tests/models/test_nequip_framework.py" }
+          - { name: nequix, test_path: "tests/models/test_nequix.py" }
           - { name: orb, test_path: "tests/models/test_orb.py" }
           - { name: sevenn, test_path: "tests/models/test_sevennet.py" }
         exclude:
@@ -154,7 +155,7 @@ jobs:
       - name: Find example scripts
         id: set-matrix
         run: |
-          EXAMPLES=$(find examples -name "*.py" | jq -R -s -c 'split("\n")[:-1]')
+          EXAMPLES=$(find examples -name "*.py" -not -path "examples/benchmarking/*" | jq -R -s -c 'split("\n")[:-1]')
           echo "examples=$EXAMPLES" >> $GITHUB_OUTPUT
 
   test-examples:
```
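The changed `find`/`jq` pipeline builds the workflow's example-script matrix: `find` lists every `.py` under `examples/`, now skipping `examples/benchmarking/`, and `jq -R -s -c 'split("\n")[:-1]'` slurps the newline-separated list into one string, splits it, and drops the trailing empty entry to emit a compact JSON array. For readers without `jq`, the same filter can be sketched in Python — `example_matrix` and the directory layout here are illustrative, not part of the repo:

```python
import json
from pathlib import Path


def example_matrix(root: str = "examples") -> str:
    """JSON array of .py files under root, excluding the benchmarking dir.

    Mirrors: find examples -name "*.py" -not -path "examples/benchmarking/*"
             | jq -R -s -c 'split("\n")[:-1]'
    """
    paths = sorted(
        str(p)
        for p in Path(root).rglob("*.py")
        if "benchmarking" not in p.parts  # skip examples/benchmarking/*
    )
    # separators=(",", ":") matches jq's compact (-c) output
    return json.dumps(paths, separators=(",", ":"))
```

Feeding the resulting JSON string into `fromJSON()` in a downstream job's `strategy.matrix` is what the workflow's `$GITHUB_OUTPUT` line enables.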

.gitignore

Lines changed: 3 additions & 0 deletions
```diff
@@ -28,6 +28,9 @@ docs/reference/torch_sim.*
 *.hdf5
 *.traj
 
+# ignore torch.save outputs
+*.pt
+
 # coverage
 coverage.xml
 .coverage*
```

README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -14,7 +14,7 @@ era. By rewriting the core primitives of atomistic simulation in Pytorch, it all
 orders of magnitude acceleration of popular machine learning potentials.
 
 * Automatic batching and GPU memory management allowing significant simulation speedup
-* Support for MACE, Fairchem, SevenNet, ORB, MatterSim, graph-pes, and metatomic MLIP models
+* Support for MACE, Fairchem, SevenNet, ORB, MatterSim, graph-pes, metatomic, and Nequix MLIP models
 * Support for classical lennard jones, morse, and soft-sphere potentials
 * Molecular dynamics integration schemes like NVE, NVT Langevin, and NPT Langevin
 * Relaxation of atomic positions and cell with gradient descent and FIRE
```

docs/user/reproducibility.md

Lines changed: 20 additions & 3 deletions
````diff
@@ -79,17 +79,34 @@ Because TorchSim runs batched simulations, all systems in a batch share a single
 
 If strict reproducibility is required, keep your batching setup fixed.
 
-### Serialising the RNG state
+### Serialising state for reproducible restarts
 
-If you wish to be able to resume a session and ensure determinism you need to persist and reload the `torch.Generator` state. This can be done using `torch.save()` and `torch.Generator().set_state()`:
+To resume a simulation and ensure determinism you need to persist and reload the complete state, including the `torch.Generator` RNG state. The simplest approach is to save the full state dict with `torch.save()`:
 
 ```python
+from dataclasses import asdict
+from torch_sim.integrators import MDState
+
 # save
+torch.save(asdict(state), "checkpoint.pt")
+
+# restore (weights_only=False needed for torch.Generator in PyTorch 2.6+)
+restored = MDState(**torch.load("checkpoint.pt", weights_only=False))
+```
+
+This captures positions, momenta, forces, energy, cell, and the `torch.Generator` in a single file. Since `torch.save()` uses pickle, the generator is serialised automatically.
+
+> **Pickle caveat**: The `torch.Generator` object in the dict requires `weights_only=False` and may not unpickle across PyTorch versions. For portable checkpoints, save the tensors normally and extract the RNG state as a plain `uint8` tensor via `get_state()` — this loads with `weights_only=True` and is version-safe:
+
+```python
+# save RNG state as a plain uint8 tensor (no pickle needed)
 rng_state = state.rng.get_state()
 torch.save(rng_state, "rng_state.pt")
 
 # restore
 gen = torch.Generator(device=state.device)
-gen.set_state(torch.load("rng_state.pt"))
+gen.set_state(torch.load("rng_state.pt", weights_only=True))
 state.rng = gen
 ```
+
+See the [reproducible restart tutorial](../../examples/tutorials/reproducible_restart_tutorial.py) for a complete worked example.
````
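The portable RNG round-trip recommended in the reproducibility doc can be demonstrated end to end. This is a standalone sketch, not TorchSim-specific: it uses a plain CPU `torch.Generator` in place of `state.rng`, and the checkpoint file name is illustrative:

```python
import torch

# Seed and advance a generator, as a running simulation would.
gen = torch.Generator()
gen.manual_seed(42)
_ = torch.randn(4, generator=gen)  # consume some randomness

# Save the RNG state as a plain uint8 tensor: no pickled Generator object,
# so the file reloads with weights_only=True and survives version bumps.
rng_state = gen.get_state()
torch.save(rng_state, "rng_state.pt")

# Restore into a fresh generator.
restored = torch.Generator()
restored.set_state(torch.load("rng_state.pt", weights_only=True))

# Both generators now produce identical draws from this point on.
assert torch.equal(
    torch.randn(4, generator=gen),
    torch.randn(4, generator=restored),
)
```

The final assertion is the determinism guarantee a restart relies on: after `set_state()`, the restored generator continues the exact random sequence the original would have produced.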
