2 changes: 2 additions & 0 deletions .gitattributes
@@ -1,2 +1,4 @@
 *.ipynb filter=nbstripout
 *.ipynb diff=ipynb
+
+examples/zundel_i-PI.ipynb filter= diff=
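(The per-file line overrides the repository-wide patterns above it: assigning empty values to filter and diff is the usual way to exempt a single notebook from nbstripout, presumably so this example ships with its stored outputs intact.)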
54 changes: 48 additions & 6 deletions bindings/rascal/models/krr.py
@@ -354,6 +354,7 @@ def train_gap_model(
     X_sparse,
     y_train,
     self_contributions,
+    solver="Normal",
     grad_train=None,
     lambdas=None,
     jitter=1e-8,
@@ -428,6 +429,15 @@ def train_gap_model(
     jitter : double, optional
         small jitter for the numerical stability of solving the linear system,
         by default 1e-8
+    solver : string, optional
+        method used to solve the sparse KRR equations.
+        "Normal" uses a least-squares solver for the normal equations:
[review comment] Contributor: "Normal" feels strange as a name. How about Standard or Direct or Direct Least Square?
+            (K_NM.T@K_NM + K_MM)^(-1) K_NM.T@Y
+        "RKHS" first computes the reproducing kernel features by diagonalizing
+        K_MM and computing P_NM = K_NM @ U_MM @ Lam_MM^(-1/2), then solves the
+        (usually better conditioned) linear problem for those features:
+            (P_NM.T@P_NM + I)^(-1) P_NM.T@Y
+        by default "Normal"
[review comment] Contributor: Document the "QR" mode please - how is it different from "RKHS"?

     Returns
     -------
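As a sanity check on the two formulas documented above, here is a small self-contained sketch (not part of the PR; the toy kernel matrices and sizes are made up for illustration) showing that, when no eigenvalues are truncated, the "Normal" and "RKHS" routes give the same weights:

```python
import numpy as np

rng = np.random.default_rng(0)
n_train, n_sparse = 200, 20          # hypothetical sizes

# toy kernel matrices: KMM symmetric positive definite, KNM rectangular
X = rng.normal(size=(n_train, 5))
S = X[:n_sparse]
KNM = np.exp(-((X[:, None, :] - S[None, :, :]) ** 2).sum(-1))
KMM = np.exp(-((S[:, None, :] - S[None, :, :]) ** 2).sum(-1))
Y = rng.normal(size=(n_train, 1))

# "Normal": solve the normal equations directly
w_normal = np.linalg.solve(KMM + KNM.T @ KNM, KNM.T @ Y)

# "RKHS": diagonalize KMM, build the feature projector, solve in feature space
eva, eve = np.linalg.eigh(KMM)
PKT = eve @ np.diag(1.0 / np.sqrt(eva))   # no truncation in this toy case
PNM = KNM @ PKT
w_rkhs = PKT @ np.linalg.solve(PNM.T @ PNM + np.eye(n_sparse), PNM.T @ Y)

print(np.allclose(w_normal, w_rkhs))      # True (up to numerical precision)
```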
@@ -463,15 +473,47 @@ def train_gap_model(
         F /= lambdas[1] / delta
         Y = np.vstack([Y, F])
 
-    KMM[np.diag_indices_from(KMM)] += jitter
+    if solver == "Normal":
+        # Finds the KRR weights using the normal equations
+        K = KMM + np.dot(KNM.T, KNM)
+        Y = np.dot(KNM.T, Y)
+        weights = np.linalg.lstsq(K, Y, rcond=jitter)[0]
+        del K
+    elif solver == "QR":
+        # Finds the KRR weights by solving an extended system, see e.g.
+        # Foster et al., JMLR (2009)
+        V = np.linalg.cholesky(KMM)
+        A = np.vstack([KNM, V.T])
+        b = np.vstack([Y, np.zeros((len(KMM), 1))])
+        weights = np.linalg.lstsq(A, b, rcond=jitter)[0]
+        del V, A, b
+    elif solver == "RKHS":
+        # Finds the weights by computing the RKHS features explicitly and
+        # solving a least-squares model on them
+        eva, eve = np.linalg.eigh(KMM)
+        eva = eva[::-1]
+        eve = eve[:, ::-1]
+
+        # drop eigenvectors whose eigenvalue, relative to the largest,
+        # is smaller than the jitter
+        nrkhs = len(np.where(eva / eva[0] > jitter)[0])
+        print("Retaining ", nrkhs, " RKHS components out of ", len(eva))
+        PKT = eve[:, :nrkhs] @ np.diag(1.0 / np.sqrt(eva[:nrkhs]))
+
+        # This would be the direct least-squares solution ...
+        # PNM = KNM @ PKT
+        # weights = PKT @ np.linalg.solve(PNM.T @ PNM + np.eye(nrkhs), PNM.T @ Y)
+
+        # ... but instead we use an alternative (equivalent) formulation
+        # based on QR, which is usually better conditioned
+        A = np.vstack([KNM @ PKT, np.eye(nrkhs)])
+        b = np.vstack([Y, np.zeros((nrkhs, 1))])
+        weights = PKT @ np.linalg.lstsq(A, b, rcond=None)[0]
+
+        del PKT, eva, eve, A, b

-    K = KMM + np.dot(KNM.T, KNM)
-    Y = np.dot(KNM.T, Y)
-    weights = np.linalg.lstsq(K, Y, rcond=None)[0]
     model = KRR(weights, kernel, X_sparse, self_contributions)
 
     # avoid memory clogging
-    del K, KMM
-    K, KMM = [], []
+    del KMM
 
     return model
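For readers wondering why the "QR" branch solves the right problem: stacking the Cholesky factor V of K_MM under K_NM and padding Y with zeros turns the regularized problem into an ordinary least-squares one, since minimizing ||K_NM w - Y||^2 + ||V.T w||^2 expands to (K_NM.T K_NM + V V.T) w = K_NM.T Y, and V V.T = K_MM. A minimal sketch with made-up toy matrices (not the PR's code):

```python
import numpy as np

rng = np.random.default_rng(1)
n_train, n_sparse = 100, 10          # hypothetical sizes

KNM = rng.normal(size=(n_train, n_sparse))
B = rng.normal(size=(n_sparse, n_sparse))
KMM = B @ B.T + np.eye(n_sparse)     # any symmetric positive-definite matrix
Y = rng.normal(size=(n_train, 1))

# normal-equations solution: (K_NM.T K_NM + K_MM)^(-1) K_NM.T Y
w_direct = np.linalg.solve(KNM.T @ KNM + KMM, KNM.T @ Y)

# extended-system solution: plain least squares on the stacked matrix
V = np.linalg.cholesky(KMM)                      # KMM = V @ V.T
A = np.vstack([KNM, V.T])
b = np.vstack([Y, np.zeros((n_sparse, 1))])
w_qr = np.linalg.lstsq(A, b, rcond=None)[0]

print(np.allclose(w_direct, w_qr))   # True: V @ V.T restores the K_MM term
```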