Skip to content

Commit 69a0b13

Browse files
committed
Add unit tests to ensure dataset copy works
1 parent ebc9737 commit 69a0b13

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

conda_package/tests/test_conversion.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from mpas_tools.io import write_netcdf
7-
from mpas_tools.mesh.conversion import convert, cull, mask
7+
from mpas_tools.mesh.conversion import _masks_to_int, convert, cull, mask
88
from mpas_tools.mesh.spherical import recompute_angle_edge
99

1010
from .util import get_test_data_file
@@ -47,6 +47,27 @@ def test_conversion_angle_edge():
4747
assert np.max(np.abs(angle_diff)) < 1.0e-10
4848

4949

50+
def test_masks_to_int_dataset_copy():
51+
ds_in = xarray.Dataset(
52+
data_vars={
53+
'regionCellMasks': (
54+
('nCells', 'nRegions'),
55+
np.array([[True, False], [False, True]]),
56+
),
57+
'cullCell': (('nCells',), np.array([False, True])),
58+
'xCell': (('nCells',), np.array([1.0, 2.0])),
59+
},
60+
attrs={'meshName': 'unit-test'},
61+
)
62+
63+
ds_out = _masks_to_int(ds_in)
64+
65+
assert ds_out.regionCellMasks.dtype == np.int32
66+
assert ds_out.cullCell.dtype == np.int32
67+
assert ds_out.attrs == ds_in.attrs
68+
assert np.array_equal(ds_out.xCell.values, ds_in.xCell.values)
69+
70+
5071
if __name__ == '__main__':
5172
test_conversion()
5273
test_conversion_angle_edge()

conda_package/tests/test_viz_transects.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
#!/usr/bin/env python
2+
import numpy as np
23
import xarray as xr
34

45
from mpas_tools.cime.constants import constants
56
from mpas_tools.mesh.conversion import convert, cull
67
from mpas_tools.planar_hex import make_planar_hex_mesh
78
from mpas_tools.translate import center
89
from mpas_tools.viz.mesh_to_triangles import mesh_to_triangles
10+
from mpas_tools.viz.transect.horiz import _fix_periodic_tris
911
from mpas_tools.viz.transect import (
1012
find_planar_transect_cells_and_weights,
1113
find_spherical_transect_cells_and_weights,
@@ -130,6 +132,24 @@ def test_find_planar_transect_cells_and_weights():
130132
assert var in ds_transect.data_vars
131133

132134

135+
def test_fix_periodic_tris_dataset_copy():
136+
ds_tris = xr.Dataset(
137+
data_vars={
138+
'xNode': (
139+
('nTriangles', 'nNodes'),
140+
np.array([[0.0, 8.0, -8.0]]),
141+
),
142+
}
143+
)
144+
145+
ds_new = _fix_periodic_tris(ds_tris, 'xNode', period=10.0)
146+
147+
assert ds_new.sizes['nTriangles'] == 3
148+
assert np.allclose(ds_new.xNode.values[0, :], [0.0, -2.0, 2.0])
149+
assert np.allclose(ds_new.xNode.values[1, :], [10.0, 8.0, 12.0])
150+
assert np.allclose(ds_new.xNode.values[2, :], [-10.0, -12.0, -8.0])
151+
152+
133153
def _get_triangles():
134154
ds_mesh = xr.open_dataset(get_test_data_file('mesh.QU.1920km.151026.nc'))
135155
earth_radius = constants['SHR_CONST_REARTH']

0 commit comments

Comments
 (0)