Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
9d2048e
exclude gnome for full downloads if needed
tschaume Mar 5, 2025
505ddfe
query s3 for trajectories
tsmathis Oct 23, 2025
aee0f8c
add deltalake query support
tsmathis Oct 23, 2025
d5a25b1
linting + mistaken sed replace on 'where'
tsmathis Oct 23, 2025
2de051d
return trajectory as pmg dict
tsmathis Oct 23, 2025
7d0b8b7
update trajectory test
tsmathis Oct 23, 2025
7195adf
correct docstrs
tsmathis Oct 23, 2025
33b787f
Merge branch 'main' into deltalake
tschaume Oct 24, 2025
2664fcd
get access controlled batch ids from heartbeat
tsmathis Nov 3, 2025
b498a76
refactor
tsmathis Nov 4, 2025
7da6984
Merge branch 'main' into deltalake
tschaume Nov 4, 2025
948c108
auto dependency upgrades
invalid-email-address Nov 5, 2025
b0aed4f
Update testing.yml
tschaume Nov 5, 2025
a35bcb7
rm overlooked access of removed settings param
tsmathis Nov 5, 2025
9460601
refactor: consolidate requests to heartbeat for meta info
tsmathis Nov 5, 2025
05f1d0e
lint
tsmathis Nov 5, 2025
e685445
fix incomplete docstr
tsmathis Nov 5, 2025
bb0b238
typo
tsmathis Nov 5, 2025
dc0c949
Merge branch 'main' into deltalake
tsmathis Nov 10, 2025
fb84d73
revert testing endpoint
tsmathis Nov 10, 2025
5bdacf5
no parallel on batch_id_neq_any
tsmathis Nov 10, 2025
7ee5515
more resilient dataset path expansion
tsmathis Nov 12, 2025
ae7674d
missed field annotation update
tsmathis Nov 12, 2025
5538c74
coerce Path to str for deltalake lib
tsmathis Nov 12, 2025
f39c0d3
flush based on bytes
tsmathis Nov 14, 2025
a965255
iterate over individual rows for local dataset
tsmathis Nov 14, 2025
03b38e7
missed bounds check for updated iteration behavior
tsmathis Nov 14, 2025
3a44b4f
opt for module level logging over warnings lib
tsmathis Nov 14, 2025
b2a832f
lint
tsmathis Nov 14, 2025
4b4af48
Merge branch 'main' into deltalake
tsmathis Feb 9, 2026
9cf0713
missed during merge-conflict resolution
tsmathis Feb 9, 2026
ff17bea
bump deltalake
tsmathis Feb 9, 2026
cd6e4a4
explicit casts for arrow types for data read from delta
tsmathis Feb 9, 2026
0cf6a40
auto dependency upgrades
invalid-email-address Feb 9, 2026
7284d74
raise warnings for pythonic usage of MPDatasets
tsmathis Feb 9, 2026
961e21c
Automated dependency upgrades (#1058)
tsmathis Feb 9, 2026
e09fd48
incomplete docstr for MPDataset
tsmathis Feb 9, 2026
92f88ac
fix get_trajectory helper func + test
tsmathis Feb 9, 2026
d2c651f
missed passing mpdataset kwargs to lazy subresters on init
tsmathis Feb 10, 2026
2f3960e
more ergonomic count w/ updated deltalake
tsmathis Feb 10, 2026
88a7803
more idiomatic string formatting
tsmathis Feb 10, 2026
4291b38
merge conflicts
esoteric-ephemera Feb 24, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 181 additions & 19 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import itertools
import os
import platform
import shutil
import sys
import warnings
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
Expand All @@ -18,15 +19,13 @@
from importlib.metadata import PackageNotFoundError, version
from json import JSONDecodeError
from math import ceil
from typing import (
TYPE_CHECKING,
ForwardRef,
Optional,
get_args,
)
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
from urllib.parse import quote, urljoin

import pyarrow as pa
import pyarrow.dataset as ds
import requests
from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
from emmet.core.utils import jsanitize
from pydantic import BaseModel, create_model
from requests.adapters import HTTPAdapter
Expand All @@ -36,7 +35,7 @@
from urllib3.util.retry import Retry

from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import load_json, validate_ids
from mp_api.client.core.utils import MPDataset, load_json, validate_ids

try:
import boto3
Expand Down Expand Up @@ -71,6 +70,7 @@ class BaseRester:
document_model: type[BaseModel] | None = None
supports_versions: bool = False
primary_key: str = "material_id"
delta_backed: bool = False

def __init__(
self,
Expand All @@ -85,6 +85,8 @@ def __init__(
timeout: int = 20,
headers: dict | None = None,
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE,
force_renew: bool = False,
):
"""Initialize the REST API helper class.

Expand Down Expand Up @@ -116,6 +118,9 @@ def __init__(
timeout: Time in seconds to wait until a request timeout error is thrown
headers: Custom headers for localhost connections.
mute_progress_bars: Whether to disable progress bars.
local_dataset_cache: Target directory for downloading full datasets. Defaults
to 'materialsproject_datasets' in the user's home directory
force_renew: Option to overwrite existing local dataset
"""
# TODO: think about how to migrate from PMG_MAPI_KEY
self.api_key = api_key or os.getenv("MP_API_KEY")
Expand All @@ -129,6 +134,8 @@ def __init__(
self.timeout = timeout
self.headers = headers or {}
self.mute_progress_bars = mute_progress_bars
self.local_dataset_cache = local_dataset_cache
self.force_renew = force_renew
self.db_version = BaseRester._get_database_version(self.endpoint)

if self.suffix:
Expand Down Expand Up @@ -356,10 +363,7 @@ def _patch_resource(
raise MPRestError(str(ex))

def _query_open_data(
self,
bucket: str,
key: str,
decoder: Callable | None = None,
self, bucket: str, key: str, decoder: Callable | None = None
) -> tuple[list[dict] | list[bytes], int]:
"""Query and deserialize Materials Project AWS open data s3 buckets.

Expand Down Expand Up @@ -463,6 +467,12 @@ def _query_resource(
url += "/"

if query_s3:
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)

db_version = self.db_version.replace(".", "-")
if "/" not in self.suffix:
suffix = self.suffix
Expand All @@ -473,15 +483,168 @@ def _query_resource(
suffix = infix if suffix == "core" else suffix
suffix = suffix.replace("_", "-")

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
# Check if user has access to GNoMe
# temp suppress tqdm
re_enable = not self.mute_progress_bars
self.mute_progress_bars = True
has_gnome_access = bool(
self._submit_requests(
url=urljoin(
"https://api.materialsproject.org/", "materials/summary/"
Comment thread
tschaume marked this conversation as resolved.
Outdated
),
criteria={
"batch_id": "gnome_r2scan_statics",
"_fields": "material_id",
},
use_document_model=False,
num_chunks=1,
chunk_size=1,
timeout=timeout,
)
.get("meta", {})
.get("total_doc", 0)
)
self.mute_progress_bars = not re_enable

if "tasks" in suffix:
bucket_suffix, prefix = "parsed", "tasks_atomate2"
bucket_suffix, prefix = ("parsed", "core/tasks/")
else:
bucket_suffix = "build"
prefix = f"collections/{db_version}/{suffix}"

bucket = f"materialsproject-{bucket_suffix}"

if self.delta_backed:
target_path = (
self.local_dataset_cache + f"/{bucket_suffix}/{prefix}"
)
os.makedirs(target_path, exist_ok=True)

if DeltaTable.is_deltatable(target_path):
if self.force_renew:
shutil.rmtree(target_path)
warnings.warn(
f"Regenerating {suffix} dataset at {target_path}...",
MPLocalDatasetWarning,
)
os.makedirs(target_path, exist_ok=True)
else:
warnings.warn(
f"Dataset for {suffix} already exists at {target_path}, delete or move existing dataset "
"or re-run search query with MPRester(force_renew=True)",
MPLocalDatasetWarning,
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

tbl = DeltaTable(
f"s3a://{bucket}/{prefix}",
storage_options={
"AWS_SKIP_SIGNATURE": "true",
"AWS_REGION": "us-east-1",
},
)

controlled_batch_str = ",".join(
[f"'{tag}'" for tag in SETTINGS.ACCESS_CONTROLLED_BATCH_IDS]
Comment thread
tschaume marked this conversation as resolved.
Outdated
)

predicate = (
" WHERE batch_id NOT IN (" # don't delete leading space
+ controlled_batch_str
+ ")"
if not has_gnome_access
else ""
)

builder = QueryBuilder().register("tbl", tbl)

# Setup progress bar
num_docs_needed = pa.table(
builder.execute("SELECT COUNT(*) FROM tbl").read_all()
)[0][0].as_py()

# TODO: Update tasks (+ others?) resource to have emmet-api BatchIdQuery operator
# -> need to modify BatchIdQuery operator to handle root level
# batch_id, not only builder_meta.batch_id
# if not has_gnome_access:
# num_docs_needed = self.count(
# {"batch_id_neq_any": SETTINGS.ACCESS_CONTROLLED_BATCH_IDS}
# )

pbar = (
tqdm(
desc=pbar_message,
total=num_docs_needed,
)
if not self.mute_progress_bars
else None
)

iterator = builder.execute("SELECT * FROM tbl" + predicate)

file_options = ds.ParquetFileFormat().make_write_options(
compression="zstd"
)

def _flush(accumulator, group):
ds.write_dataset(
accumulator,
base_dir=target_path,
format="parquet",
basename_template=f"group-{group}-"
+ "part-{i}.zstd.parquet",
existing_data_behavior="overwrite_or_ignore",
max_rows_per_group=1024,
file_options=file_options,
)

group = 1
size = 0
accumulator = []
for page in iterator:
# arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
accumulator.append(pa.record_batch(page))
page_size = page.num_rows
size += page_size

if pbar is not None:
pbar.update(page_size)

if size >= SETTINGS.DATASET_FLUSH_THRESHOLD:
_flush(accumulator, group)
group += 1
size = 0
accumulator = []
Comment thread
tschaume marked this conversation as resolved.
Outdated

if accumulator:
_flush(accumulator, group + 1)

convert_to_deltalake(target_path)

warnings.warn(
f"Dataset for {suffix} written to {target_path}. It is recommended to optimize "
"the table according to your usage patterns prior to running intensive workloads, "
"see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout",
MPLocalDatasetWarning,
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

Expand Down Expand Up @@ -518,11 +681,6 @@ def _query_resource(
}

# Setup progress bar
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)
num_docs_needed = int(self.count())
pbar = (
tqdm(
Expand Down Expand Up @@ -1350,3 +1508,7 @@ class MPRestError(Exception):

class MPRestWarning(Warning):
    """Emitted for queries that are malformed yet still interpretable."""


class MPLocalDatasetWarning(Warning):
    """Emitted when an irreversible operation is applied to a local dataset."""
14 changes: 14 additions & 0 deletions mp_api/client/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,18 @@ class MAPIClientSettings(BaseSettings):
_MAX_LIST_LENGTH, description="Maximum length of query parameter list"
)

LOCAL_DATASET_CACHE: str = Field(
os.path.expanduser("~") + "/mp_datasets",
Comment thread
tsmathis marked this conversation as resolved.
Outdated
description="Target directory for downloading full datasets",
)

DATASET_FLUSH_THRESHOLD: int = Field(
100000,
Comment thread
tsmathis marked this conversation as resolved.
Outdated
description="Threshold number of rows to accumulate in memory before flushing dataset to disk",
)

ACCESS_CONTROLLED_BATCH_IDS: list[str] = Field(
Comment thread
tschaume marked this conversation as resolved.
Outdated
["gnome_r2scan_statics"], description="Batch ids with access restrictions"
)

model_config = SettingsConfigDict(env_prefix="MPRESTER_")
64 changes: 64 additions & 0 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

import re
from functools import cached_property
from itertools import chain
from typing import TYPE_CHECKING, Literal

import orjson
import pyarrow.dataset as ds
from deltalake import DeltaTable
from emmet.core import __version__ as _EMMET_CORE_VER
from monty.json import MontyDecoder
from packaging.version import parse as parse_version
from pydantic._internal._model_construction import ModelMetaclass

from mp_api.client.core.settings import MAPIClientSettings

Expand Down Expand Up @@ -124,3 +129,62 @@ def validate_monty(cls, v, _):
monty_cls.validate_monty_v2 = classmethod(validate_monty)

return monty_cls


class MPDataset:
    """Convenience wrapper for a pyarrow/Delta Lake dataset stored on disk.

    The dataset is chunked by parquet row group: ``len``, indexing, and
    iteration all operate on whole chunks, where each chunk is returned as a
    list of plain dicts (or document-model instances when
    ``use_document_model`` is enabled).
    """

    def __init__(self, path, document_model, use_document_model):
        """Open the on-disk dataset and index its row-group chunks.

        Args:
            path: Filesystem path to the dataset. Also used to lazily open a
                ``DeltaTable`` view via :attr:`delta_table`, so it is assumed
                to be a valid Delta table path — TODO confirm with callers.
            document_model: Pydantic model class used to validate rows.
            use_document_model: If True, each row is wrapped in
                ``document_model``; otherwise raw dicts are returned.
        """
        self._start = 0  # first chunk index consumed by __iter__
        self._path = path
        self._document_model = document_model
        self._dataset = ds.dataset(path)
        # Flatten every fragment's row groups into one indexable list of
        # chunk handles; data is only materialized on access in __getitem__.
        self._row_groups = list(
            chain.from_iterable(
                fragment.split_by_row_group()
                for fragment in self._dataset.get_fragments()
            )
        )
        self._use_document_model = use_document_model

    @property
    def pyarrow_dataset(self) -> ds.Dataset:
        """Underlying ``pyarrow.dataset.Dataset``."""
        return self._dataset

    @property
    def pydantic_model(self) -> ModelMetaclass:
        """Document model class used to validate returned rows."""
        return self._document_model

    @property
    def use_document_model(self) -> bool:
        """Whether rows are returned as document-model instances."""
        return self._use_document_model

    @use_document_model.setter
    def use_document_model(self, value: bool):
        self._use_document_model = value

    @cached_property
    def delta_table(self) -> DeltaTable:
        """Lazily constructed ``DeltaTable`` view of the same path."""
        return DeltaTable(self._path)

    @cached_property
    def num_chunks(self) -> int:
        """Total number of row-group chunks in the dataset."""
        return len(self._row_groups)

    def __getitem__(self, idx):
        """Return chunk ``idx`` as a list of dicts or model instances.

        Raises:
            IndexError: If ``idx`` is out of range of the row-group list.
        """
        rows = self._row_groups[idx].to_table().to_pylist(maps_as_pydicts="strict")
        # Branch once per chunk rather than per row.
        if self._use_document_model:
            return [self._document_model(**row) for row in rows]
        return rows

    def __len__(self) -> int:
        """Number of chunks (row groups) — NOT the number of rows."""
        return self.num_chunks

    def __iter__(self):
        """Yield chunks sequentially, starting from ``self._start``."""
        for idx in range(self._start, self.num_chunks):
            yield self[idx]
Loading
Loading