-
Notifications
You must be signed in to change notification settings - Fork 1
Add two function that deal with overlapping atoms #479
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
13d4655
3b2d46a
cdcf346
667ad4e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| """Utilities that operate purely geometric aspects of structures.""" | ||
|
|
||
| import numpy as np | ||
|
|
||
| from structuretoolkit.analyse import get_neighbors | ||
|
|
||
|
|
||
| def repulse( | ||
| structure, | ||
| min_dist=1.5, | ||
| step_size=0.2, | ||
| axis=None, | ||
| iterations: int = 100, | ||
| inplace=False, | ||
| ): | ||
| """Displace atoms to avoid minimum overlap. | ||
|
|
||
| Args: | ||
| structure (:class:`ase.Atoms`): | ||
| structure to modify | ||
| min_dist (float): | ||
| Minimum distance to enforce between atoms | ||
| step_size (float): | ||
| Maximum distance to displace atoms in one step | ||
| iterations (int): | ||
| Maximum number of displacements made before giving up | ||
| """ | ||
| if not inplace: | ||
| structure = structure.copy() | ||
| if axis is None: | ||
| axis = slice(None) | ||
| for _ in range(iterations): | ||
| neigh = get_neighbors(structure, num_neighbors=1) | ||
| dd = neigh.distances[:, 0] | ||
| if dd.min() > min_dist: | ||
| break | ||
|
|
||
| I = dd < min_dist | ||
|
|
||
| vv = neigh.vecs[I, 0, :] | ||
| vv /= dd[I, None] | ||
|
|
||
| disp = np.clip(min_dist - dd[I], 0, step_size) | ||
|
|
||
| displacement = disp[:, None] * vv # (N_close, 3) | ||
| structure.positions[I, axis] -= displacement[:, axis] | ||
|
|
||
| else: | ||
| raise RuntimeError(f"repulse did not converge within {iterations} iterations") | ||
|
|
||
| return structure | ||
|
|
||
|
|
||
| def merge( | ||
| structure: "ase.Atoms", cutoff: float = 1.8, iterations: int = 10 | ||
| ) -> "ase.Atoms": | ||
| """Merge pairs of atoms that are closer than ``cutoff`` by collapsing each | ||
| pair to their midpoint and deleting one of the two atoms. | ||
|
|
||
| The operation is applied repeatedly (up to ``iterations`` times) to handle | ||
| cases where a merge creates new close contacts. | ||
|
|
||
| .. note:: | ||
| The structure is modified **in place**. Pass a copy if you need the | ||
| original to remain unchanged. | ||
|
|
||
| Args: | ||
| structure (:class:`ase.Atoms`): | ||
| Structure to modify. | ||
| cutoff (float): | ||
| Distance threshold in Ångström below which two atoms are | ||
| considered clashing and will be merged. Defaults to ``1.8``. | ||
| iterations (int): | ||
| Maximum number of recursive merge passes. Defaults to ``10``. | ||
|
|
||
| Returns: | ||
| :class:`ase.Atoms`: The modified structure with clashing atom pairs | ||
| replaced by single atoms at their midpoints. | ||
| """ | ||
| neigh = get_neighbors(structure, 1) | ||
| clashing = np.argwhere(neigh.distances[:, 0] < cutoff).ravel() | ||
| if len(clashing) == 0: | ||
| return structure | ||
|
|
||
| moving = [] | ||
| deleting = [] | ||
|
|
||
| for c in clashing: | ||
| if c in deleting: | ||
| continue | ||
|
|
||
| moving.append(c) | ||
| deleting.append(neigh.indices[c, 0]) | ||
|
|
||
| structure.positions[moving] += neigh.vecs[moving, 0] / 2 | ||
| del structure[deleting] | ||
|
|
||
| if iterations > 0: | ||
| return merge(structure, cutoff=cutoff, iterations=iterations - 1) | ||
| return structure | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "merge", | ||
| "repulse", | ||
| ] | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,123 @@ | ||||||||||||||||||||||||||||||||
| import unittest | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||
| from ase import Atoms | ||||||||||||||||||||||||||||||||
| from ase.build import bulk | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| from structuretoolkit.analyse import get_neighbors | ||||||||||||||||||||||||||||||||
| from structuretoolkit.build.geometry import merge, repulse | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| class TestRepulse(unittest.TestCase): | ||||||||||||||||||||||||||||||||
| def setUp(self): | ||||||||||||||||||||||||||||||||
| self.atoms = bulk("Cu", cubic=True).repeat(5) | ||||||||||||||||||||||||||||||||
| self.atoms.rattle(1) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_noop(self): | ||||||||||||||||||||||||||||||||
| """If no atoms are violating min_dist, atoms should be unchanged.""" | ||||||||||||||||||||||||||||||||
| atoms = bulk("Cu", cubic=True).repeat(5) # perfect FCC, no rattle | ||||||||||||||||||||||||||||||||
| original_positions = atoms.positions.copy() | ||||||||||||||||||||||||||||||||
| # Cu nearest-neighbor ~2.55 Å, so min_dist=2.0 triggers no displacement | ||||||||||||||||||||||||||||||||
| result = repulse(atoms, min_dist=2.0) | ||||||||||||||||||||||||||||||||
| np.testing.assert_array_equal(result.positions, original_positions) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_inplace(self): | ||||||||||||||||||||||||||||||||
| """Input atoms should be copied depending on `inplace`.""" | ||||||||||||||||||||||||||||||||
| original_positions = self.atoms.positions.copy() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # inplace=False (default): original must be untouched, result is a copy | ||||||||||||||||||||||||||||||||
| result = repulse(self.atoms, inplace=False) | ||||||||||||||||||||||||||||||||
| np.testing.assert_array_equal(self.atoms.positions, original_positions) | ||||||||||||||||||||||||||||||||
| self.assertIsNot(result, self.atoms) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # inplace=True: result is the same object as the input | ||||||||||||||||||||||||||||||||
| result2 = repulse(self.atoms, inplace=True) | ||||||||||||||||||||||||||||||||
| self.assertIs(result2, self.atoms) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_iterations(self): | ||||||||||||||||||||||||||||||||
| """Should raise error if iterations exhausted.""" | ||||||||||||||||||||||||||||||||
| # min_dist=5.0 is far larger than any achievable spacing; step_size tiny | ||||||||||||||||||||||||||||||||
| # → convergence is impossible, so iterations will be exhausted | ||||||||||||||||||||||||||||||||
| with self.assertRaises(RuntimeError): | ||||||||||||||||||||||||||||||||
| repulse(self.atoms, min_dist=5.0, step_size=0.001, iterations=2) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_axis(self): | ||||||||||||||||||||||||||||||||
| """If axis given, the other axes should be exactly untouched.""" | ||||||||||||||||||||||||||||||||
| atoms = self.atoms.copy() | ||||||||||||||||||||||||||||||||
| original_y = atoms.positions[:, 1].copy() | ||||||||||||||||||||||||||||||||
| original_z = atoms.positions[:, 2].copy() | ||||||||||||||||||||||||||||||||
| # Modify inplace so we can inspect positions even if convergence fails | ||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||
| repulse(atoms, axis=0, inplace=True) | ||||||||||||||||||||||||||||||||
| except RuntimeError: | ||||||||||||||||||||||||||||||||
| pass # convergence irrelevant; we only care which axes were touched | ||||||||||||||||||||||||||||||||
| np.testing.assert_array_equal(atoms.positions[:, 1], original_y) | ||||||||||||||||||||||||||||||||
| np.testing.assert_array_equal(atoms.positions[:, 2], original_z) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_min_dist(self): | ||||||||||||||||||||||||||||||||
| """min_dist must be respected after a call to repulse.""" | ||||||||||||||||||||||||||||||||
| min_dist = 1.5 | ||||||||||||||||||||||||||||||||
| result = repulse(self.atoms, min_dist=min_dist) | ||||||||||||||||||||||||||||||||
| neigh = get_neighbors(result, num_neighbors=1) | ||||||||||||||||||||||||||||||||
| self.assertGreaterEqual(neigh.distances[:, 0].min(), min_dist) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _two_atom_structure(d: float) -> Atoms: | ||||||||||||||||||||||||||||||||
| """Return a two-Cu-atom cell with atoms separated by ``d`` Å along x.""" | ||||||||||||||||||||||||||||||||
| return Atoms("Cu2", positions=[[0, 0, 0], [d, 0, 0]], cell=[20, 20, 20], pbc=True) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| class TestMerge(unittest.TestCase): | ||||||||||||||||||||||||||||||||
| def test_noop(self): | ||||||||||||||||||||||||||||||||
| """Perfect FCC Cu has no contacts below default cutoff; structure is unchanged.""" | ||||||||||||||||||||||||||||||||
| atoms = bulk("Cu", cubic=True).repeat(3) | ||||||||||||||||||||||||||||||||
| original_positions = atoms.positions.copy() | ||||||||||||||||||||||||||||||||
| result = merge(atoms) | ||||||||||||||||||||||||||||||||
| self.assertEqual(len(result), len(atoms)) | ||||||||||||||||||||||||||||||||
| np.testing.assert_array_equal(result.positions, original_positions) | ||||||||||||||||||||||||||||||||
|
Comment on lines
+71
to
+77
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test compares length against same object due to in-place modification. Since 💚 Proposed fix def test_noop(self):
"""Perfect FCC Cu has no contacts below default cutoff; structure is unchanged."""
atoms = bulk("Cu", cubic=True).repeat(3)
original_positions = atoms.positions.copy()
+ original_len = len(atoms)
result = merge(atoms)
- self.assertEqual(len(result), len(atoms))
+ self.assertEqual(len(result), original_len)
np.testing.assert_array_equal(result.positions, original_positions)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_reduces_count(self): | ||||||||||||||||||||||||||||||||
| """Two atoms within cutoff are collapsed into one.""" | ||||||||||||||||||||||||||||||||
| atoms = _two_atom_structure(0.5) | ||||||||||||||||||||||||||||||||
| result = merge(atoms, cutoff=1.8) | ||||||||||||||||||||||||||||||||
| self.assertEqual(len(result), 1) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_midpoint(self): | ||||||||||||||||||||||||||||||||
| """The surviving atom must sit at the midpoint of the original pair.""" | ||||||||||||||||||||||||||||||||
| atoms = _two_atom_structure(1.0) | ||||||||||||||||||||||||||||||||
| result = merge(atoms, cutoff=1.8) | ||||||||||||||||||||||||||||||||
| self.assertEqual(len(result), 1) | ||||||||||||||||||||||||||||||||
| np.testing.assert_allclose(result.positions[0], [0.5, 0, 0], atol=1e-10) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_cutoff_respected(self): | ||||||||||||||||||||||||||||||||
| """Atoms just beyond the cutoff must not be merged.""" | ||||||||||||||||||||||||||||||||
| atoms = _two_atom_structure(2.0) | ||||||||||||||||||||||||||||||||
| result = merge(atoms, cutoff=1.8) | ||||||||||||||||||||||||||||||||
| self.assertEqual(len(result), 2) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_multiple_pairs(self): | ||||||||||||||||||||||||||||||||
| """All clashing pairs in the structure are merged in one call.""" | ||||||||||||||||||||||||||||||||
| # Two independent close pairs, well separated from each other | ||||||||||||||||||||||||||||||||
| atoms = Atoms( | ||||||||||||||||||||||||||||||||
| "Cu4", | ||||||||||||||||||||||||||||||||
| positions=[[0, 0, 0], [0.5, 0, 0], [10, 0, 0], [10.5, 0, 0]], | ||||||||||||||||||||||||||||||||
| cell=[30, 30, 30], | ||||||||||||||||||||||||||||||||
| pbc=True, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| result = merge(atoms, cutoff=1.8) | ||||||||||||||||||||||||||||||||
| self.assertEqual(len(result), 2) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def test_iterations_zero_stops_early(self): | ||||||||||||||||||||||||||||||||
| """With iterations=0 only one pass runs; further clashes are left unresolved.""" | ||||||||||||||||||||||||||||||||
| # Three atoms in a row: 0–1 and 1–2 clash. First pass merges 0+1 → 0.25. | ||||||||||||||||||||||||||||||||
| # The merged atom is now ~9.75 Å from atom 2, so one pass is enough here | ||||||||||||||||||||||||||||||||
| # — we just verify the function returns without error. | ||||||||||||||||||||||||||||||||
| atoms = Atoms( | ||||||||||||||||||||||||||||||||
| "Cu3", | ||||||||||||||||||||||||||||||||
| positions=[[0, 0, 0], [0.5, 0, 0], [10, 0, 0]], | ||||||||||||||||||||||||||||||||
| cell=[20, 20, 20], | ||||||||||||||||||||||||||||||||
| pbc=True, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| result = merge(atoms, cutoff=1.8, iterations=0) | ||||||||||||||||||||||||||||||||
| # At least one merge happened | ||||||||||||||||||||||||||||||||
| self.assertLess(len(result), 3) | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix undefined
asein type hints.The type hints reference
ase.Atomsbutaseis not imported. Since this is used only for documentation purposes, use a string literal or add aTYPE_CHECKINGimport.🔧 Proposed fix using string literal (simplest)
Or with proper TYPE_CHECKING import:
📝 Committable suggestion
🧰 Tools
🪛 Ruff (0.15.9)
[error] 55-55: Undefined name
ase(F821)
[error] 56-56: Undefined name
ase(F821)
🤖 Prompt for AI Agents