If strict reproducibility is required, keep your batching setup fixed.
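As a minimal illustration (plain PyTorch, not TorchSim's API), a fixed seed combined with a fixed batch layout yields bit-identical random draws across runs; the `draw` helper below is a hypothetical sketch, not a TorchSim function:

```python
import torch


def draw(seed: int, n: int) -> torch.Tensor:
    # A dedicated, seeded generator isolates the random stream
    # from global RNG state, so repeat runs are deterministic.
    gen = torch.Generator().manual_seed(seed)
    return torch.randn(n, generator=gen)


# Same seed, same shape -> bit-identical tensors
assert torch.equal(draw(42, 8), draw(42, 8))
```

Changing the batch size or ordering changes which draws land on which system, which is why the batching setup must stay fixed for strict reproducibility.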
### Serialising state for reproducible restarts
To resume a simulation deterministically, you need to persist and reload the complete state, including the `torch.Generator` RNG state. The simplest approach is to save the full state dict with `torch.save()`:
```python
import torch
from dataclasses import asdict
from torch_sim.integrators import MDState
# save
torch.save(asdict(state), "checkpoint.pt")

# restore (weights_only=False needed for torch.Generator in PyTorch 2.6+)
state = MDState(**torch.load("checkpoint.pt", weights_only=False))
```
This captures positions, momenta, forces, energy, cell, and the `torch.Generator` in a single file. Since `torch.save()` uses pickle, the generator is serialised automatically.
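The same round trip can be sketched self-contained with a toy dataclass standing in for `MDState` (the class and field names here are illustrative, not TorchSim's):

```python
import os
import tempfile
from dataclasses import asdict, dataclass

import torch


@dataclass
class ToyState:  # stand-in for MDState; illustrative fields only
    positions: torch.Tensor
    momenta: torch.Tensor


state = ToyState(positions=torch.randn(4, 3), momenta=torch.zeros(4, 3))
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(asdict(state), path)

# weights_only=False lets pickle rebuild arbitrary objects in the dict
restored = ToyState(**torch.load(path, weights_only=False))
assert torch.equal(restored.positions, state.positions)
```

Because `asdict()` returns a plain dict of fields, reconstructing the dataclass is just unpacking that dict back into the constructor.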
> **Pickle caveat**: The `torch.Generator` object in the dict requires `weights_only=False` and may not unpickle across PyTorch versions. For portable checkpoints, save the tensors normally and extract the RNG state as a plain `uint8` tensor via `get_state()` — this loads with `weights_only=True` and is version-safe:
```python
# save RNG state as a plain uint8 tensor (no pickle needed)
# (`generator` here is the simulation's torch.Generator)
torch.save(generator.get_state(), "rng_state.pt")

# restore with the safe default loader
generator.set_state(torch.load("rng_state.pt", weights_only=True))
```