Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
cdb9d9d
Add save_metric=1 to adapt sampler args
amas0 Nov 18, 2025
62fefae
Add metric files to RunSet
amas0 Nov 18, 2025
c7fef6b
Add metric files runset tests
amas0 Nov 18, 2025
838b78c
Add initial metric parsing logic
amas0 Nov 18, 2025
f7a9ae9
Fix string name collisions cmdstan args test
amas0 Nov 18, 2025
c47791b
Lazily load metric info from file
amas0 Dec 2, 2025
4823796
Fix field_validator to be classmethod
amas0 Dec 2, 2025
d23a9d6
Properly handle one process per chain metric output
amas0 Dec 2, 2025
66512d5
Remove _step_size initialization from assemble_draws
amas0 Dec 2, 2025
4a40ef7
Allow stepsize to be nan
amas0 Dec 2, 2025
cb493b5
Short-circuit metric properties to None when fixed param
amas0 Dec 2, 2025
84ae036
Only enable save_metric=1 when adapt engaged
amas0 Dec 2, 2025
063dfb9
Add metric file output for testing CmdStanMCMC construction from outp…
amas0 Dec 2, 2025
2d076af
Merge branch 'stan-dev:develop' into enable-save-metric
amas0 Dec 2, 2025
b138308
Add metric files for runset-big
amas0 Dec 2, 2025
29ddbfd
Fix metric output filenames test to reflect one proc per chain
amas0 Dec 2, 2025
4502aef
Remove functionality and tests for parsing metric info from CSV
amas0 Dec 2, 2025
d80461d
Add pydantic as a dependency
amas0 Dec 2, 2025
0f5dab8
Add tests of MetricInfo validators
amas0 Dec 2, 2025
82a85d2
Remove unused chain_id from MetricInfo
amas0 Dec 12, 2025
1635915
Remove stringified type hints
amas0 Dec 12, 2025
1f0f0a9
Clarify arbitrary_types_allowed usage
amas0 Dec 12, 2025
5fcee74
Remove _metric_info_parsed
amas0 Dec 15, 2025
2497754
Add tests for invalid metric type
amas0 Dec 16, 2025
37080af
Convert MetricInfo.inv_metric to native Python types
amas0 Dec 16, 2025
b5926c9
Fixup mypy issue in tests
amas0 Dec 17, 2025
915eef5
Convert to list for test clarity
amas0 Dec 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ def compose(self, idx: int, cmd: list[str]) -> list[str]:
cmd.append(f'window={self.adapt_metric_window}')
if self.adapt_step_size is not None:
cmd.append('term_buffer={}'.format(self.adapt_step_size))
if self.adapt_engaged:
cmd.append('save_metric=1')
# End adapt subsection

if self.num_chains > 1:
cmd.append('num_chains={}'.format(self.num_chains))

Expand Down
69 changes: 47 additions & 22 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
stancsv,
)

from .metadata import InferenceMetadata
from .metadata import InferenceMetadata, MetricInfo
from .runset import RunSet


Expand Down Expand Up @@ -81,6 +81,7 @@ def __init__(
# info from CSV values, instantiated lazily
self._draws: np.ndarray = np.array(())
# only valid when not is_fixed_param
self._metric_type: str | None = None
self._metric: np.ndarray = np.array(())
self._step_size: np.ndarray = np.array(())
self._divergences: np.ndarray = np.zeros(self.runset.chains, dtype=int)
Expand All @@ -92,6 +93,8 @@ def __init__(
# info from CSV header and initial and final comment blocks
config = self._validate_csv_files()
self._metadata: InferenceMetadata = InferenceMetadata(config)
self._chain_metric_info: list[MetricInfo] = []
self._metric_info_parsed: bool = False
Comment thread
WardBrian marked this conversation as resolved.
Outdated
if not self._is_fixed_param:
self._check_sampler_diagnostics()

Expand Down Expand Up @@ -216,11 +219,13 @@ def metric_type(self) -> str | None:
to CmdStan arg 'metric'.
When sampler algorithm 'fixed_param' is specified, metric_type is None.
"""
return (
self._metadata.cmdstan_config['metric']
if not self._is_fixed_param
else None
)
if self._is_fixed_param:
return None

if not self._metric_info_parsed:
self._parse_metric_info()

return self._metric_type

@property
def inv_metric(self) -> np.ndarray | None:
Expand All @@ -230,10 +235,15 @@ def inv_metric(self) -> np.ndarray | None:
a ``nchains x nparams x nparams`` array when metric_type is 'dense_e',
or ``None`` when metric_type is 'unit_e' or algorithm is 'fixed_param'.
"""
if self._is_fixed_param or self.metric_type == 'unit_e':
if self._is_fixed_param:
return None

if not self._metric_info_parsed:
self._parse_metric_info()

if self.metric_type == 'unit_e':
return None

self._assemble_draws()
return self._metric

@property
Expand All @@ -242,8 +252,13 @@ def step_size(self) -> np.ndarray | None:
Step size used by sampler for each chain.
When sampler algorithm 'fixed_param' is specified, step size is None.
"""
self._assemble_draws()
return self._step_size if not self._is_fixed_param else None
if self._is_fixed_param:
return None

if not self._metric_info_parsed:
self._parse_metric_info()

return self._step_size

@property
def thin(self) -> int:
Expand Down Expand Up @@ -382,6 +397,27 @@ def _validate_csv_files(self) -> dict[str, Any]:
self._max_treedepths[i] = drest['ct_max_treedepth']
return dzero

def _parse_metric_info(self) -> None:
"""Extracts metric type, inv_metric, and step size information from the
parsed metric JSONs."""
self._chain_metric_info = [
MetricInfo.from_json(mf, chain_id)
for mf, chain_id in zip(
self.runset.metric_files, self.runset.chain_ids
)
]
metric_types = {cmi.metric_type for cmi in self._chain_metric_info}
if len(metric_types) != 1:
raise ValueError("Inconsistent metric types found across chains")
self._metric_type = self._chain_metric_info[0].metric_type
self._metric = np.asarray(
[cmi.inv_metric for cmi in self._chain_metric_info]
)
Comment thread
amas0 marked this conversation as resolved.
self._step_size = np.asarray(
[cmi.stepsize for cmi in self._chain_metric_info]
)
self._metric_info_parsed = True

def _check_sampler_diagnostics(self) -> None:
"""
Warn if any iterations ended in divergences or hit maxtreedepth.
Expand Down Expand Up @@ -424,13 +460,11 @@ def _assemble_draws(self) -> None:
dtype=np.float64,
order='F',
)
self._step_size = np.empty(self.chains, dtype=np.float64)

mass_matrix_per_chain = []
for chain in range(self.chains):
try:
(
comments,
_,
header,
draws,
) = stancsv.parse_comments_header_and_draws(
Expand All @@ -443,20 +477,11 @@ def _assemble_draws(self) -> None:
draws_np = np.empty((0, n_cols))

self._draws[:, chain, :] = draws_np
if not self._is_fixed_param:
(
self._step_size[chain],
mass_matrix,
) = stancsv.parse_hmc_adaptation_lines(comments)
mass_matrix_per_chain.append(mass_matrix)
except Exception as exc:
raise ValueError(
f"Parsing output from {self.runset.csv_files[chain]} failed"
) from exc

if all(mm is not None for mm in mass_matrix_per_chain):
self._metric = np.array(mass_matrix_per_chain)

assert self._draws is not None

def summary(
Expand Down
58 changes: 57 additions & 1 deletion cmdstanpy/stanfit/metadata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""Container for metadata parsed from the output of a CmdStan run"""

from __future__ import annotations
Comment thread
WardBrian marked this conversation as resolved.

import copy
import json
import math
import os
from typing import Any, Iterator
from typing import Any, Iterator, Literal

import numpy as np
import stanio
from pydantic import BaseModel, Field, field_validator, model_validator

from cmdstanpy.utils import stancsv

Expand Down Expand Up @@ -79,3 +85,53 @@ def stan_vars(self) -> dict[str, stanio.Variable]:
These are the user-defined variables in the Stan program.
"""
return self._stan_vars


class MetricInfo(BaseModel):
"""Structured representation of HMC-NUTS metric information,
as output by CmdStan"""

chain_id: int = Field(gt=0)
Comment thread
amas0 marked this conversation as resolved.
Outdated
stepsize: float
metric_type: Literal["diag_e", "dense_e", "unit_e"]
inv_metric: np.ndarray

model_config = {"arbitrary_types_allowed": True}
Comment thread
amas0 marked this conversation as resolved.
Outdated

@field_validator("inv_metric", mode="before")
@classmethod
def convert_inv_metric(cls, v: Any) -> np.ndarray:
return np.asarray(v)

@field_validator("stepsize")
@classmethod
def validate_stepsize(cls, v: float) -> float:
if not math.isnan(v) and v <= 0:
raise ValueError("stepsize must be greater than 0 or NaN")
return v

@model_validator(mode="after")
def validate_inv_metric_shape(self) -> MetricInfo:
if (
self.metric_type in ("diag_e", "unit_e")
and self.inv_metric.ndim != 1
):
raise ValueError(
"inv_metric must be 1D for diag_e and unit_e metric type"
)
if self.metric_type == "dense_e":
if self.inv_metric.ndim != 2:
raise ValueError("Dense inv_metric must be 2D")
if self.inv_metric.shape[0] != self.inv_metric.shape[1]:
raise ValueError("Dense inv_metric must be square")

return self

@classmethod
def from_json(cls, file: str | os.PathLike, chain_id: int) -> MetricInfo:
"""Parse and validate a metric json given a file path and chain_id"""
with open(file) as f:
info_dict = json.load(f)

info_dict['chain_id'] = chain_id
return cls.model_validate(info_dict) # type: ignore
24 changes: 24 additions & 0 deletions cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
self._stdout_files, self._profile_files = [], []
self._csv_files, self._diagnostic_files = [], []
self._config_files = []
self._metric_files = []

# per-process output files
if one_process_per_chain and chains > 1:
Expand Down Expand Up @@ -87,6 +88,10 @@ def __init__(
# per-chain output files
if chains == 1:
self._csv_files = [self.gen_file_name(".csv")]
if args.method == Method.SAMPLE:
self._metric_files = [
self.gen_file_name(".json", extra="metric")
]
if args.save_latent_dynamics:
self._diagnostic_files = [
self.gen_file_name(".csv", extra="diagnostic")
Expand All @@ -95,6 +100,20 @@ def __init__(
self._csv_files = [
self.gen_file_name(".csv", id=id) for id in self._chain_ids
]
if args.method == Method.SAMPLE:
if one_process_per_chain:
self._metric_files = [
os.path.join(
self._outdir,
f"{self._base_outfile}_{id}_metric.json",
)
for id in self._chain_ids
]
else:
self._metric_files = [
self.gen_file_name(".json", extra="metric", id=id)
for id in self._chain_ids
]
if args.save_latent_dynamics:
self._diagnostic_files = [
self.gen_file_name(".csv", extra="diagnostic", id=id)
Expand Down Expand Up @@ -222,6 +241,11 @@ def profile_files(self) -> list[str]:
"""List of paths to CmdStan profiler files."""
return self._profile_files

@property
def metric_files(self) -> list[str]:
"""List of paths to CmdStan NUTS-HMC sampler metric files."""
return self._metric_files

def gen_file_name(
self, suffix: str, *, extra: str = "", id: int | None = None
) -> str:
Expand Down
100 changes: 0 additions & 100 deletions cmdstanpy/utils/stancsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,44 +103,6 @@ def csv_bytes_list_to_numpy(
return out


def parse_hmc_adaptation_lines(
comment_lines: list[bytes],
) -> tuple[float | None, npt.NDArray[np.float64] | None]:
"""Extracts step size/mass matrix information from the Stan CSV comment
lines by parsing the adaptation section. If the diag_e metric is used,
the returned mass matrix will be a 1D array of the diagnoal elements,
if the dense_e metric is used, it will be a 2D array representing the
entire matrix, and if unit_e is used then None will be returned.

Returns a (step_size, mass_matrix) tuple"""
step_size, mass_matrix = None, None

cleaned_lines = (ln.lstrip(b"# ") for ln in comment_lines)
in_matrix_block = False
diag_e_metric = False
matrix_lines = []
for line in cleaned_lines:
if in_matrix_block and line.strip():
# Stop when we get to timing block
if line.startswith(b"Elapsed Time"):
break
matrix_lines.append(line)
elif line.startswith(b"Step size"):
_, ss_str = line.split(b" = ")
step_size = float(ss_str)
elif line.startswith(b"Diagonal") or line.startswith(b"Elements"):
in_matrix_block = True
elif line.startswith(b"No free"):
break
elif b"diag_e" in line:
diag_e_metric = True
if matrix_lines:
mass_matrix = csv_bytes_list_to_numpy(matrix_lines)
if diag_e_metric and mass_matrix.shape[0] == 1:
mass_matrix = mass_matrix[0]
return step_size, mass_matrix


def extract_key_val_pairs(
comment_lines: list[bytes], remove_default_text: bool = True
) -> Iterator[tuple[str, str]]:
Expand Down Expand Up @@ -346,67 +308,6 @@ def column_count(ln: bytes) -> int:
)


def raise_on_invalid_adaptation_block(comment_lines: list[bytes]) -> None:
"""Throws ValueErrors if the parsed adaptation block is invalid, e.g.
the metric information is not present, consistent with the rest of
the file, or the step size info cannot be processed."""

def column_count(ln: bytes) -> int:
return ln.count(b",") + 1

ln_iter = enumerate(comment_lines, start=2)
metric = None
for _, line in ln_iter:
if b"metric =" in line:
_, val = line.split(b" = ")
metric = val.replace(b"(Default)", b"").strip().decode()
if b"Adaptation terminated" in line:
break
else: # No adaptation block found
raise ValueError("No adaptation block found, expecting metric")

if metric is None:
raise ValueError("No reported metric found")
# At this point iterator should be in the adaptation block

# Ensure step size exists and is valid float
num, line = next(ln_iter)
if not line.startswith(b"# Step size"):
raise ValueError(
f"line {num}: expecting step size, found:\n\t \"{line.decode()}\""
)
_, step_size = line.split(b" = ")
try:
float(step_size.strip())
except ValueError as exc:
raise ValueError(
f"line {num}: invalid step size: {step_size.decode()}"
) from exc

# Ensure mass matrix valid
num, line = next(ln_iter)
if metric == "unit_e":
return
if not (
(metric == "diag_e" and line.startswith(b"# Diagonal elements of "))
or (metric == "dense_e" and line.startswith(b"# Elements of inverse"))
):
raise ValueError(
f"line {num}: invalid or missing mass matrix specification"
)

# Validating mass matrix shape
_, line = next(ln_iter)
num_unconstrained_params = column_count(line)
if metric == "diag_e":
return
for (num, line), _ in zip(ln_iter, range(1, num_unconstrained_params)):
if column_count(line) != num_unconstrained_params:
raise ValueError(
f"line {num}: invalid or missing mass matrix specification"
)


def parse_timing_lines(
comment_lines: list[bytes],
) -> dict[str, float]:
Expand Down Expand Up @@ -489,7 +390,6 @@ def parse_sampler_metadata_from_csv(
and header
and not is_sneaky_fixed_param(header)
):
raise_on_invalid_adaptation_block(comments)
max_depth: int = config["max_depth"] # type: ignore
max_tree_hits, divs = extract_max_treedepth_and_divergence_counts(
header, draws, max_depth, num_warmup
Expand Down
Loading