Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 31 additions & 70 deletions pipt/misc_tools/data_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
__all__ = [
'combine_ensemble_predictions',
'en_pred_to_pred_data',
'melt_adjoint_to_sensitivity',
'combine_ensemble_dataframes',
'combine_adjoint_ensemble',
'merge_dataframes',
'multilevel_to_singlelevel_columns',
'dataframe_to_series',
'series_to_dataframe',
'series_to_matrix',
'dataframe_to_matrix',
'multilevel_to_singlelevel_columns'
'dataframe_to_matrix'
]


Expand Down Expand Up @@ -140,42 +138,7 @@ def en_pred_to_pred_data(en_pred):
return pred_data


def melt_adjoint_to_sensitivity(adjoint: pd.DataFrame, datatype: list, idX: dict):
    '''
    Flatten a MultiIndex-column adjoint DataFrame with structure
    (datatype, param) into a long-format sensitivity DataFrame.

    For every (index, datatype) cell, the per-parameter gradients are
    concatenated into one 1-D array. A scalar NaN cell is replaced by a
    zero block whose length is taken from ``idX``; NaNs inside stored
    arrays are replaced by 0.

    Parameters
    ----------
    adjoint : pd.DataFrame
        Object-dtype DataFrame with MultiIndex columns (datatype, param).
        Cells hold gradient arrays, scalars, or NaN.
    datatype : list
        Desired ordering of the datatype column level.
    idX : dict
        Maps each parameter name to its (start, stop) range in the state
        vector; used both for parameter ordering and to size the
        zero-fill for missing cells.

    Returns
    -------
    pd.DataFrame
        Long-format frame indexed like ``adjoint`` with columns
        'datatype' and 'adjoint' (the concatenated gradient array).
    '''
    # Order both column levels consistently with the caller-supplied orderings.
    adj_datatype = sorted(adjoint.columns.levels[0], key=lambda x: datatype.index(x))
    param_order = list(idX.keys())
    adj_params = sorted(adjoint.columns.levels[1], key=lambda x: param_order.index(x))

    sens = pd.DataFrame(columns=adj_datatype, index=adjoint.index)
    for idx in sens.index:
        for dkey in adj_datatype:
            # Collect per-parameter pieces and concatenate once at the end
            # (avoids the quadratic cost of repeated np.append).
            parts = [np.array([])]  # float64 seed so the result is always float
            for param in adj_params:
                cell = adjoint.at[idx, (dkey, param)]
                if isinstance(cell, np.ndarray):
                    # Array of gradients: zero out any NaN entries.
                    parts.append(np.where(np.isnan(cell), 0, cell))
                elif np.isnan(cell):
                    # Missing gradient: zero-fill with this parameter's dimension.
                    start, stop = idX[param]
                    parts.append(np.zeros(stop - start))
                else:
                    # Scalar gradient: treat as a length-1 array.
                    parts.append(np.array([cell]))
            sens.at[idx, dkey] = np.concatenate(parts)

    # Melt to long format: one row per (index, datatype) pair.
    sens = sens.melt(ignore_index=False)
    sens.rename(columns={'variable': 'datatype', 'value': 'adjoint'}, inplace=True)
    return sens


def combine_ensemble_dataframes(en_dfs: list):
def merge_dataframes(en_dfs: list[pd.DataFrame]) -> pd.DataFrame:
'''
Combine a list of DataFrames (one per ensemble member) into a single DataFrame
where each cell contains an array of ensemble values.
Expand All @@ -193,32 +156,36 @@ def combine_ensemble_dataframes(en_dfs: list):
values = []
for dfn in en_dfs:
values.append(dfn.at[idx, col])
df.at[idx, col] = np.array(values).squeeze()

df.at[idx, col] = np.array(values).squeeze().T
return df

def combine_adjoint_ensemble(en_adj, datatype: list, idX: dict):

adjoints = [melt_adjoint_to_sensitivity(adj, datatype, idX) for adj in en_adj]
def multilevel_to_singlelevel_columns(df: pd.DataFrame) -> pd.DataFrame:
"""
Convert a MultiIndex-column DataFrame with structure (key, param)
into a DataFrame with one column per key, where the value is
the concatenation of all param-arrays for that key.
"""
result = {}

index = adjoints[0].index
index_name = adjoints[0].index.name
keys = adjoints[0]['datatype'].values
keys = sorted(keys, key=lambda x: datatype.index(x))
# Top-level keys (level 0 of MultiIndex), preserving first appearance order
keys = pd.Index(df.columns.get_level_values(0)).unique()

#df = pd.DataFrame(columns=['datatype', 'adjoint'], index=index, dtype=object)
for key in keys:
# Extract all columns for this key → list of arrays per row
param_arrays = df[key] # this is a sub-dataframe for this key

data = {'datatype': [], 'adjoint': []}
for i, idx in enumerate(index):
data['datatype'].append(keys[i])
matrix = []
for adj in adjoints:
matrix.append(adj.iloc[i]['adjoint'])
data['adjoint'].append(np.array(matrix).T) # Transpose to get correct shape (n_param, n_ensembles)
# For each row, concatenate arrays from all params
concatenated = [
np.concatenate(param_arrays.iloc[i].values)
for i in range(len(df))
]

df = pd.DataFrame(data, index=index)
df.index.name = index_name
return df
result[key] = concatenated

df_new = pd.DataFrame(result, index=df.index)
df_new.index.name = df.index.name
return df_new


def dataframe_to_series(df):
mult_index = []
Expand Down Expand Up @@ -250,16 +217,10 @@ def dataframe_to_matrix(df):
series = dataframe_to_series(df)
return series_to_matrix(series)

def multilevel_to_singlelevel_columns(df):
    """
    Collapse a MultiIndex-column DataFrame with structure (key, param) into a
    single-level DataFrame with one column per key.

    For each row, the cell value of a key column is the concatenation of that
    row's per-parameter arrays (scalar cells are treated as length-1 arrays).

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with two column levels (key, param); cells hold 1-D arrays
        or scalars.

    Returns
    -------
    pd.DataFrame
        DataFrame indexed like ``df`` with one object column per key.
    """
    # Top-level keys, preserving first-appearance order.
    keys = pd.Index(df.columns.get_level_values(0)).unique()

    result = {}
    for key in keys:
        sub = df[key]  # sub-frame holding all params for this key
        # Concatenate the per-param arrays row by row. The previous version
        # concatenated across ALL rows at once, producing a vector of length
        # n_rows * n_params that could not be assigned to an n_rows column.
        result[key] = [
            np.concatenate([np.atleast_1d(v) for v in sub.iloc[i].values])
            for i in range(len(df))
        ]

    df_new = pd.DataFrame(result, index=df.index)
    df_new.index.name = df.index.name
    return df_new







Expand Down
48 changes: 26 additions & 22 deletions pipt/update_schemes/enrml.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def calc_analysis(self):
)

# Store the (mean) data misfit (also for conv. check)
self.ensemble_misfit = data_misfit
self.data_misfit = np.mean(data_misfit)
self.prior_data_misfit = np.mean(data_misfit)
self.data_misfit_std = np.std(data_misfit)
Expand All @@ -149,8 +150,8 @@ def calc_analysis(self):

# Check for adjoint
if hasattr(self, 'adjoints'):
enAdj = dtools.combine_ensemble_dataframes(self.adjoints)
enAdj = dtools.dataframe_to_matrix(enAdj) # Shape (nd, ne, nx)
enAdj = dtools.merge_dataframes(self.adjoints)
enAdj = dtools.dataframe_to_matrix(enAdj) # Shape (nd, nx, ne)
else:
enAdj = None

Expand Down Expand Up @@ -209,7 +210,7 @@ def check_convergence(self):
# data instead.

data_misfit = at.calc_objectivefun(self.enObs, enPred, self.cov_data)

self.ensemble_misfit = data_misfit
self.data_misfit = np.mean(data_misfit)
self.data_misfit_std = np.std(data_misfit)

Expand All @@ -229,11 +230,13 @@ def check_convergence(self):

if self.data_misfit >= self.prev_data_misfit:
success = False
self.log_update(success=success)
self.logger(
f'Iterations have converged after {self.iteration} iterations. Objective function reduced '
f'from {self.prior_data_misfit:0.1f} to {self.prev_data_misfit:0.1f}'
)
else:
self.log_update(success=True)
self.logger.info(
f'Iterations have converged after {self.iteration} iterations. Objective function reduced '
f'from {self.prior_data_misfit:0.1f} to {self.data_misfit:0.1f}'
Expand All @@ -249,20 +252,21 @@ def check_convergence(self):
'prev_data_misfit': self.prev_data_misfit,
'lambda': self.lam,
'lambda_stop': self.lam >= self.lam_max}

# Log step
self.log_update(success=success)


###############################################
##### update Lambda step-size values ##########
###############################################
# If reduction in mean data misfit, reduce damping param
if self.data_misfit < self.prev_data_misfit and self.data_misfit_std < self.prev_data_misfit_std:
# Reduce damping parameter (divide calculations for ANALYSISDEBUG purpose)

success = True
self.log_update(success=success)

# Reduce damping parameter
if self.lam > self.lam_min:
self.lam = self.lam / self.gamma
self.logger(f'λ reduced: {self.lam * self.gamma} ──> {self.lam}')
success = True

# Update state ensemble
self.enX = cp.deepcopy(self.enX_temp)
Expand All @@ -274,8 +278,10 @@ def check_convergence(self):


elif self.data_misfit < self.prev_data_misfit and self.data_misfit_std >= self.prev_data_misfit_std:

# accept itaration, but keep lam the same
success = True
self.log_update(success=success)

# Update state ensemble
self.enX = cp.deepcopy(self.enX_temp)
Expand All @@ -286,10 +292,11 @@ def check_convergence(self):
self.current_W = cp.deepcopy(self.W)

else: # Reject iteration, and increase lam
# Increase damping parameter (divide calculations for ANALYSISDEBUG purpose)
success = False
self.log_update(success=success)
self.lam = self.lam * self.gamma
# Increase damping parameter (divide calculations for ANALYSISDEBUG purpose)
self.logger(f'Data misfit increased! λ increased: {self.lam / self.gamma} ──> {self.lam}')
success = False

if not success:
# Reset the objective function after report
Expand All @@ -303,21 +310,18 @@ def log_update(self, success, prior_run=False):
'''
Log the update results in a formatted table.
'''
log_data = {
"Iteration": f'{0 if prior_run else self.iteration}',
"Status": "Success" if (prior_run or success) else "Failed",
"Data Misfit": self.data_misfit,
"λ": self.lam
info = {
"Iteration" : f'{0 if prior_run else self.iteration}',
"Status" : "Success" if (prior_run or success) else "Failed",
"Data Misfit" : self.data_misfit,
"Change (%)" : '',
"λ" : self.lam
}
if not prior_run:
if success:
log_data["Reduction (%)"] = 100 * (1 - self.data_misfit / self.prev_data_misfit)
else:
log_data["Increase (%)"] = 100 * (self.data_misfit / self.prev_data_misfit - 1)
else:
log_data["Reduction (%)"] = 'N/A'
delta = 100*(self.data_misfit / self.prev_data_misfit - 1)
info["Change (%)"] = delta

self.logger(**log_data)
self.logger(**info)



Expand Down
13 changes: 12 additions & 1 deletion pipt/update_schemes/esmda.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pipt.loop.ensemble import Ensemble
import pipt.misc_tools.analysis_tools as at
import pipt.misc_tools.ensemble_tools as entools
import pipt.misc_tools.data_tools as dtools

# import update schemes
from pipt.update_schemes.update_methods_ns.approx_update import approx_update
Expand Down Expand Up @@ -158,12 +159,22 @@ def calc_analysis(self):
if 'localanalysis' in self.keys_da:
self.local_analysis_update()
else:

# Check for adjoint
if hasattr(self, 'adjoints'):
enAdj = dtools.merge_dataframes(self.adjoints)
enAdj = dtools.dataframe_to_matrix(enAdj) # Shape (nd, nx, ne)
else:
enAdj = None

# Perform the update
self.update(
enX = self.enX,
enY = self.enPred,
enE = self.enObs,
prior = self.prior_enX
# kwargs
prior = self.prior_enX,
enAdj = enAdj
)

# Update the state ensemble and weights
Expand Down
Loading