Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1d830b8
implementation of multi-stage time integrators
fernanvr May 5, 2025
7f087b3
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Jun 13, 2025
214d882
Return of first PR comments
fernanvr Jun 13, 2025
d6c4d4a
Return of first PR comments
fernanvr Jun 13, 2025
78f8a0b
2nd PR revision
fernanvr Jun 23, 2025
1c9d517
2nd PR revision
fernanvr Jun 23, 2025
11db48b
2rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
83dfb04
3rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
d47a106
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
1f93a45
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
eea3a52
5th PR revision, one suggestion from EdC and improving tests
fernanvr Jun 26, 2025
11d1429
including two more Runge-Kutta methods and improving tests: checking …
fernanvr Jul 1, 2025
4637ac2
changes to consider coupled Multistage equations
fernanvr Jul 16, 2025
ac1da7e
Improvements of the HORK_EXP
fernanvr Aug 15, 2025
e9b3533
Merge branch 'main' into multi-stage-time-integrator
fernanvr Aug 15, 2025
dc3dd77
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Oct 8, 2025
1fd4a02
tuples, improved class names, extensive tests
fernanvr Oct 8, 2025
a0c45c1
improving spacing in some tests
fernanvr Oct 8, 2025
93c6e3f
Add MFE time stepping Jupyter notebook
fernanvr Oct 23, 2025
ef8d1ac
Remove MFE_time_size.ipynb notebook
fernanvr Oct 23, 2025
fa5acac
Update multistage implementation and tests
fernanvr Oct 29, 2025
143d0c2
Return of first PR comments
fernanvr Jun 13, 2025
5f67b91
updating small changes from EdC review on 26-03-2026
fernanvr Mar 26, 2026
552fd7f
Isolate multistage-related changes only
fernanvr Mar 27, 2026
e9d2000
merging two classes of Runge-Kutta
fernanvr Mar 27, 2026
f7c9ea3
Merge full multistage history while keeping clean tree
fernanvr Mar 27, 2026
cf1003c
fixing test_multistage file
fernanvr Apr 6, 2026
1fd480b
Remove devito/ir/equations/algorithms.py and devito/operator/operator…
fernanvr Apr 9, 2026
a875224
implemented suggestions of EdC and Fabio
fernanvr Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 89 additions & 62 deletions devito/types/multistage.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this file should be moved to somewhere like devito/timestepping/rungekutta.py or devito/timestepping/explicitmultistage.py that way additional timesteppers can be contributed as new files. (I'm thinking about implicit multistage, backward difference formulae etc...)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed the class to HighOrderRungeKuttaExponential. I realize the name might be confusing since this particular Runge-Kutta is explicit, but “EXP” was intended to highlight the exponential aspect. I’ve also updated the other class names based on your suggestions.

Regarding the file location, it’s currently in /types as recommended by @mloubout (see suggestion). Personally, I think both /timestepping and /types are reasonable options. Perhaps we can discuss this with @EdCaunt and @FabioLuporini to reach a consensus.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this file doesn't belong to types/

based on https://github.com/devitocodes/devito/pull/2599/changes#r3043562368, we might add it to ir/dsl/rungekutta.py

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from devito.symbolics import uxreplace
from numpy import number
from devito.types.array import Array
from devito.types.dense import Function
from devito.types.constant import Constant
from types import MappingProxyType

method_registry = {}
Expand Down Expand Up @@ -51,7 +53,7 @@ class MultiStage(Eq):
of update expressions for each stage in the integration process.
"""

def __new__(cls, lhs, rhs, **kwargs):
def __new__(cls, lhs, rhs, source=None, degree=6, **kwargs):
if not isinstance(lhs, list):
lhs=[lhs]
rhs=[rhs]
Expand All @@ -61,6 +63,8 @@ def __new__(cls, lhs, rhs, **kwargs):
obj._eq = [Eq(lhs[i], rhs[i]) for i in range(len(lhs))]
obj._lhs = lhs
obj._rhs = rhs
obj._deg = degree
obj._src = source

return obj

Expand All @@ -79,6 +83,16 @@ def rhs(self):
"""Return list of right-hand sides."""
return self._rhs

@property
def deg(self):
"""Return list of right-hand sides."""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For immutability, these should be tuple. Ditto with the next property. They should probably be made into tuples as early as possible

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

return self._deg

@property
def src(self):
"""Return list of right-hand sides."""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repeated docstring?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

return self._src

def _evaluate(self, **kwargs):
raise NotImplementedError(
f"_evaluate() must be implemented in the subclass {self.__class__.__name__}")
Expand Down Expand Up @@ -115,7 +129,9 @@ class RK(MultiStage):
Number of stages in the RK method, inferred from `b`.
"""

def __init__(self, a: list[list[float | number]], b: list[float | number], c: list[float | number], lhs, rhs, **kwargs) -> None:
CoeffsBC = list[float | number]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does number superclass np.number?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, is the same, I changed it to np.number to avoid confusion

CoeffsA = list[CoeffsBC]
def __init__(self, a: CoeffsA, b: CoeffsBC, c: CoeffsBC, lhs, rhs, **kwargs) -> None:
self.a, self.b, self.c = a, b, c

@property
Expand All @@ -132,19 +148,18 @@ def _evaluate(self, **kwargs):

Returns
-------
list of Eq
list of Devito Eq objects
A list of SymPy Eq objects representing:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: they will be Devito Eq objects

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

- `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
"""
n_eq=len(self.eq)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs whitespace for flake8

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

u = [i.function for i in self.lhs]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider making this a tuple and a property of the class

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did another thing where instead of having the lhs of the equations, save the function. Because that is the only need I have for the lhs.

grid = [u[i].grid for i in range(n_eq)]
t = grid[0].time_dim
t = u[0].grid.time_dim
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grids = {f.grid for f in u}
if not len(grids) == 1:
    raise ValueError("Cannot construct multi-stage time integrator for Functions on disparate grids")
grid = grids.pop()
t = grid.time_dim

would be safer

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm wrong, but I think the functions should be allowed to be defined in different grids, like different staggered grids... the method should work also for those cases.

dt = t.spacing

# Create temporary Functions to hold each stage
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: these are Array now

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right!

k = [[Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=grid[j].dimensions, grid=grid[j], dtype=u[j].dtype) for i in range(self.s)]
k = [[Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=u[j].grid.dimensions, grid=u[j].grid, dtype=u[j].dtype) for i in range(self.s)]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k = [[Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=grid.dimensions, grid=grid, dtype=f.dtype) for f in u]

or similar should be sufficient in combination with my previous comment

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even with that, all those combination are needed, because for each equation are required Arrays for all the Runge-Kutta stages.

for j in range(n_eq)]

stage_eqs = []
Expand Down Expand Up @@ -214,8 +229,8 @@ class RK32(RK):
b = [0, 0, 1]
c = [0, 1/2, 1/2]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should definitely be tuple. Exposing a mutable data structure at the class level here is potentially dangerous

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def __init__(self, *args, **kwargs):
super().__init__(a=self.a, b=self.b, c=self.c, **kwargs)
def __init__(self, lhs, rhs, **kwargs):
super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs)


@register_method
Expand Down Expand Up @@ -249,12 +264,12 @@ class RK97(RK):
5963949/25894400, 50000000000/599799373173, 28487/712800]
c = [0, 4/63, 2/21, 1/7, 7/17, 13/24, 7/9, 91/100, 1]

def __init__(self, *args, **kwargs):
super().__init__(a=self.a, b=self.b, c=self.c, **kwargs)
def __init__(self, lhs, rhs, **kwargs):
super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs)


@register_method
class HORK(MultiStage):
class HORK_EXP(MultiStage):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HORK_Exp?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HighOrderRungeKuttaExplicit perhaps

Be explicit!

Use CamelCase for class definitions and names_with_underscores for variables and functions.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also if we want to start distinguishing between explicit and implicit timesteppers we should be using a class hierarchy to do so

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, changed it. Actually, the EXP was for Exponential, not Explicit. But I agree that if/when implicit method are implemented, a class hierarchy is the best option.

# In construction
"""
n stages Runge-Kutta (HORK) time integration method.
Expand All @@ -271,8 +286,19 @@ class HORK(MultiStage):
Time positions of intermediate stages.
"""

def source_derivatives(self, src_index, t, **kwargs):

# Compute the base wavelet function
f_deriv = [[self.src[i][1] for i in range(len(self.src))]]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have a couple of one-liners like this where you use range(len()) unnecessarily. If you need a counter, there is always enumerate. However, in this case, the loop can just be:

f_deriv = [[src[1] for src in self.src]]

You might want to go through your PR and tidy them up

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I went through the PR and changed a few of the range(len()) inside the loops


# Compute derivatives up to order p
for _ in range(self.deg - 1):
f_deriv.append([f_deriv[-1][i].diff(t) for i in range(len(src_index))])

f_deriv.reverse()
Comment thread
EdCaunt marked this conversation as resolved.
return f_deriv

def ssprk_alpha(mu=1, **kwargs):
def ssprk_alpha(self, mu=1):
"""
Computes the coefficients for the Strong Stability Preserving Runge-Kutta (SSPRK) method.

Expand All @@ -287,18 +313,33 @@ def ssprk_alpha(mu=1, **kwargs):
numpy.ndarray
Array of SSPRK coefficients.
"""
degree=kwargs.get('degree')

alpha = [0]*degree
alpha = [0]*self.deg
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.zeros(self.deg)?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could do that, but I thought that in cases like this is was better to use list. Isn't?

alpha[0] = 1.0 # Initial coefficient

for i in range(1, degree):
alpha[i] = 1 / (mu * (i + 1)) * alpha[i - 1]
alpha[1:i] = 1 / (mu * list(range(1, i))) * alpha[:i - 1]
for i in range(1, self.deg):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: Some comments would help here

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some comments, tell me if that is what you had in mind...

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All looks reasonable

alpha[i] = 1/(mu*(i+1))*alpha[i-1]
alpha[1:i] = [1/(mu*j)*alpha[j-1] for j in range(1,i)]
alpha[0] = 1 - sum(alpha[1:i + 1])

return alpha
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't look to be an array getting returned

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed



def source_inclusion(self, u, k, src_index, src_deriv, e_p, t, dt, mu, n_eq):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function could also use a few more comments. Additionally, is there any way to reduce the number of args? Possibly some of these args should be properties of self

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did something in this direction...


src_lhs = [uxreplace(self.rhs[i], {u[m]: k[m] for m in range(n_eq)}) for i in range(n_eq)]

p = len(src_deriv)

for i in range(p):
if e_p[i] != 0:
for j in range(len(src_index)):
src_lhs[src_index[j]] += self.src[j][0]*src_deriv[i][j].subs({t: t * dt})*e_p[i]
e_p = [e_p[i]+mu*dt*e_p[i + 1] for i in range(p - 1)]+[e_p[-1]]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would append [e_p[-1]] on the next line for readability

Copy link
Copy Markdown
Author

@fernanvr fernanvr Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but if I do that, e_p will be updated and e_p[-1] is not going to be the value I need...


return src_lhs, e_p


def _evaluate(self, **kwargs):
"""
Generate the stage-wise equations for a Runge-Kutta time integration method.
Expand All @@ -315,66 +356,52 @@ def _evaluate(self, **kwargs):
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
"""

u = self.lhs.function
rhs = self.rhs
grid = u.grid
t = grid.time_dim
n_eq=len(self.eq)
u = [i.function for i in self.lhs]
t = u[0].grid.time_dim
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the tweaks to this in my earlier comment. If it keeps popping up in different methods, then it should probably be a property of self

Copy link
Copy Markdown
Author

@fernanvr fernanvr Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hummm, I get your point, in this case, n_eq, t and dt should be properties of self

dt = t.spacing

an_eq = range(len(U0))
# Create a temporary Array for each variable to save the time stages
# k = [Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=u[i].grid.dimensions, grid=u[i].grid, dtype=u[i].dtype) for i in range(n_eq)]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be Array? Additionally, you should pull the symbol registry out at the top of the function. Furthermore, should these Arrays only be generated during lower_multistage?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is to be Array, but since there are still some problems when using Array I left it comment.

Related to the other comment, I think I could define the Arrays earlier and pass them as arguments. Is that what you have in mind?

k = [Function(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', grid=u[i].grid, space_order=2, time_order=1, dtype=u[i].dtype) for i in range(n_eq)]
k_old = [Function(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', grid=u[i].grid, space_order=2, time_order=1, dtype=u[i].dtype) for i in range(n_eq)]

# Compute SSPRK coefficients
alpha = np.array(ssprk_alpha(mu, degree), dtype=np.float64)
mu = 1
alpha = self.ssprk_alpha(mu=mu)
Comment thread
EdCaunt marked this conversation as resolved.

# Initialize symbolic differentiation for source terms
t_var = sym.Symbol('t_var')
src_deriv = aux_fun.derivates_f(degree, f0)
src_index_map={val: i for i, val in enumerate(u)}
src_index = [src_index_map[val] for val in [self.src[i][2] for i in range(len(self.src))]]
src_deriv = self.source_derivatives(src_index, t, **kwargs)

# Expansion coefficients for stability control
e_p = [0] * degree
e_p = [0] * self.deg
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.zeros?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could do that, but I thought that in cases like this is was better to use list. Isn't?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably fine then

eta = 1
e_p[-1] = 1 / eta

# Initialize approximation and auxiliary variable
approx = [U0[i] * alpha[0] for i in n_eq]
aux = U0

# Perform Runge-Kutta steps
for i in range(1, degree - 1):
system_op, e_p = sys_op_extended(aux, x, y, z, param_fun, system, fd_order, src_spat, src_deriv, t, dt, t_var, e_p)
aux = [aux[j] + mu * dt * system_op[j] for j in n_eq]
approx = [approx[j] + aux[j] * alpha[i] for j in n_eq]

# Final Runge-Kutta updates
system_op, e_p = sys_op_extended(aux, x, y, z, param_fun, system, fd_order, src_spat, src_deriv, t, dt, t_var, e_p)
aux = [aux[i] + mu * dt * system_op[i] for i in n_eq]
system_op, e_p = sys_op_extended(aux, x, y, z, param_fun, system, fd_order, src_spat, src_deriv, t, dt, t_var, e_p)
aux = [aux[i] + mu * dt * system_op[i] for i in n_eq]

# Compute final approximation
approx = [approx[i] + aux[i] * alpha[degree - 1] for i in n_eq]

# Generate final PDE system
return [dv.Eq(U0[i].forward, approx[i]) for i in n_eq]

# Create temporary Functions to hold each stage
# k = [Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
k = [Function(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', grid=grid, space_order=u.space_order, dtype=u.dtype)
for i in range(self.s)]

stage_eqs = []
stage_eqs = [Eq(k[i], u[i]) for i in range(n_eq)]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be tidier using zip

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

[stage_eqs.append(Eq(u[i].forward, u[i]*alpha[0])) for i in range(n_eq)]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't love this one-lining. Make this an explicit for loop, or use stage_eqs.extend() instead

Copy link
Copy Markdown
Author

@fernanvr fernanvr Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


# Build each stage
for i in range(self.s):
u_temp = u + dt * sum(aij * kj for aij, kj in zip(self.a[i][:i], k[:i]))
t_shift = t + self.c[i] * dt
for i in range(1, self.deg-1):
[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, purge this style of one-lining throughout, and use .extend instead

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

src_lhs, e_p = self.source_inclusion(u, k_old, src_index, src_deriv, e_p, t, dt, mu, n_eq)
[stage_eqs.append(Eq(k[j], k_old[j]+mu*dt*src_lhs[j])) for j in range(n_eq)]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lots of these comprehensions should use zip rather than indexing into lists/arrays

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

[stage_eqs.append(Eq(u[j].forward, u[j].forward+k[j]*alpha[i])) for j in range(n_eq)]

# Evaluate RHS at intermediate value
stage_rhs = uxreplace(rhs, {u: u_temp, t: t_shift})
stage_eqs.append(Eq(k[i], stage_rhs))
# Final Runge-Kutta updates
[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]
src_lhs, e_p = self.source_inclusion(u, k_old, src_index, src_deriv, e_p, t, dt, mu, n_eq)
[stage_eqs.append(Eq(k[j], k_old[j]+mu*dt*src_lhs[j])) for j in range(n_eq)]

# Final update: u.forward = u + dt * sum(b_i * k_i)
u_next = u + dt * sum(bi * ki for bi, ki in zip(self.b, k))
stage_eqs.append(Eq(u.forward, u_next))
[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]
src_lhs, _ = self.source_inclusion(u, k_old, src_index, src_deriv, e_p, t, dt, mu, n_eq)
[stage_eqs.append(Eq(k[j], k_old[j]+mu*dt*src_lhs[j])) for j in range(n_eq)]

# Compute final approximation
[stage_eqs.append(Eq(u[j].forward, u[j].forward+k[j]*alpha[self.deg-1])) for j in range(n_eq)]

return stage_eqs

Expand Down
Loading
Loading