Skip to content

Commit 1832a1a

Browse files
authored
add sel argument (#32)
1 parent b22fe4c commit 1832a1a

3 files changed

Lines changed: 52 additions & 6 deletions

File tree

tests/test_20_open_dataset.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from collections.abc import Hashable
12
from pathlib import Path
3+
from typing import Any
24

35
import pytest
46
import xarray as xr
@@ -103,3 +105,34 @@ def test_combine_coords(tmp_path: Path, index_node: str) -> None:
103105
)
104106
assert set(ds.coords) == {"areacella", "lat", "lon", "experiment_id", "orog"}
105107
assert not ds.data_vars
108+
109+
110+
@pytest.mark.parametrize(
    "sel,expected_size",
    [
        # No selection: all 12 time steps of the file are expected.
        ({}, 12),
        # Plain label selection on the time coordinate.
        ({"time": "2019-01"}, 1),
        # Range selection spelled as a {"slice": [start, stop]} mapping.
        ({"time": {"slice": ["2019-01", "2019-02"]}}, 2),
    ],
)
def test_time_selection(
    tmp_path: Path,
    index_node: str,
    sel: dict[Hashable, Any],
    expected_size: int,
) -> None:
    """Check that the ``sel`` argument restricts the ``time`` dimension.

    A single-file query is opened through ``xr.open_dataset`` with
    ``engine="esgf"``, forwarding ``sel``; the resulting size of the
    ``time`` dimension must equal ``expected_size``.
    """
    file_query = '"tas_Amon_EC-Earth3-CC_ssp245_r1i1p1f1_gr_201901-201912.nc"'
    query_selection = {"query": [file_query]}

    dataset = xr.open_dataset(
        query_selection,  # type: ignore[arg-type]
        esgpull_path=tmp_path / "esgpull",
        engine="esgf",
        index_node=index_node,
        chunks={},
        sel=sel,
    )

    time_steps = dataset.sizes["time"]
    assert time_steps == expected_size

xarray_esgf/client.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Callable, Hashable, Iterable
66
from functools import cached_property
77
from pathlib import Path
8-
from typing import Literal, get_args
8+
from typing import Any, Literal, get_args
99

1010
import tqdm
1111
import xarray as xr
@@ -117,10 +117,15 @@ def download(self) -> list[File]:
117117
def _open_datasets(
118118
self,
119119
concat_dims: DATASET_ID_KEYS | Iterable[DATASET_ID_KEYS] | None,
120-
drop_variables: str | Iterable[str] | None = None,
121-
download: bool = False,
122-
show_progress: bool = True,
120+
drop_variables: str | Iterable[str] | None,
121+
download: bool,
122+
show_progress: bool,
123+
sel: dict[Hashable, Any],
123124
) -> dict[str, Dataset]:
125+
sel = {
126+
k: slice(*v["slice"]) if isinstance(v, dict) else v for k, v in sel.items()
127+
}
128+
124129
if isinstance(concat_dims, str):
125130
concat_dims = [concat_dims]
126131
concat_dims = concat_dims or []
@@ -139,6 +144,7 @@ def _open_datasets(
139144
drop_variables=drop_variables,
140145
storage_options={"ssl": self.verify_ssl},
141146
)
147+
ds = ds.sel({k: v for k, v in sel.items() if k in ds.dims})
142148
grouped_objects[file.dataset_id].append(ds.drop_encoding())
143149

144150
combined_datasets = {}
@@ -173,9 +179,14 @@ def open_dataset(
173179
drop_variables: str | Iterable[str] | None = None,
174180
download: bool = False,
175181
show_progress: bool = True,
182+
sel: dict[Hashable, Any] | None = None,
176183
) -> Dataset:
177184
combined_datasets = self._open_datasets(
178-
concat_dims, drop_variables, download, show_progress
185+
concat_dims=concat_dims,
186+
drop_variables=drop_variables,
187+
download=download,
188+
show_progress=show_progress,
189+
sel=sel or {},
179190
)
180191

181192
obj = xr.combine_by_coords(

xarray_esgf/engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Iterable
1+
from collections.abc import Hashable, Iterable
22
from pathlib import Path
33
from typing import Any
44

@@ -22,6 +22,7 @@ def open_dataset( # type: ignore[override]
2222
concat_dims: DATASET_ID_KEYS | Iterable[DATASET_ID_KEYS] | None = None,
2323
download: bool = False,
2424
show_progress: bool = True,
25+
sel: dict[Hashable, Any] | None = None,
2526
) -> Dataset:
2627
client = Client(
2728
selection=filename_or_obj,
@@ -36,6 +37,7 @@ def open_dataset( # type: ignore[override]
3637
drop_variables=drop_variables,
3738
download=download,
3839
show_progress=show_progress,
40+
sel=sel,
3941
)
4042

4143
open_dataset_parameters = (

0 commit comments

Comments
 (0)