From c189ef335ed498769f44f82159613bdff6f99fff Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Sun, 15 Mar 2026 22:12:38 +0100 Subject: [PATCH 1/2] Rename axis parameter to dim on DataArray/Dataset aggregation methods Aligns with xarray's API naming convention. The old axis= keyword is preserved as a deprecated keyword-only argument with FutureWarning. All tests migrated to use the new dim= parameter. --- src/mikeio/dataset/_dataarray.py | 247 +++++++++++++++++++------------ src/mikeio/dataset/_dataset.py | 230 ++++++++++++++++------------ tests/test_dataarray.py | 111 +++++++++++--- tests/test_dataset.py | 103 +++++++++++-- tests/test_dfs2.py | 2 +- tests/test_dfs3.py | 4 +- tests/test_generic.py | 10 +- tests/test_integration.py | 8 +- 8 files changed, 477 insertions(+), 238 deletions(-) diff --git a/src/mikeio/dataset/_dataarray.py b/src/mikeio/dataset/_dataarray.py index 9425bcc37..0fefa0d1b 100644 --- a/src/mikeio/dataset/_dataarray.py +++ b/src/mikeio/dataset/_dataarray.py @@ -68,6 +68,22 @@ DataArrayPlotterLineSpectrum, ) + +def _resolve_deprecated_axis( + dim: int | str | None, + axis: int | str | None, +) -> int | str | None: + """If axis= keyword was used, warn and return it. Otherwise return dim.""" + if axis is not None: + warnings.warn( + "The 'axis' keyword is deprecated. Pass the dimension directly as the first argument.", + FutureWarning, + stacklevel=3, + ) + return axis + return dim + + GeometryType = Union[ Geometry0D, GeometryUndefined, @@ -1134,11 +1150,26 @@ def interp_time( dt=self._dt, ) - def interp_na(self, axis: str = "time", **kwargs: Any) -> DataArray: + def interp_na( + self, + dim: str = "time", + *, + axis: str | None = None, + **kwargs: Any, + ) -> DataArray: """Fill in NaNs by interpolating according to different methods. Wrapper of [](`xarray.DataArray.interpolate_na`) + Parameters + ---------- + dim: str, optional + dimension to interpolate along, by default "time" + axis: str, optional + deprecated, use dim + **kwargs: Any + Additional keyword arguments passed to xarray interpolate_na + Examples -------- @@ -1155,7 +1186,9 @@ def interp_na(self, axis: str = "time", **kwargs: Any) -> DataArray: ``` """ - xr_da = self.to_xarray().interpolate_na(dim=axis, **kwargs) + resolved = _resolve_deprecated_axis(dim, axis) + assert isinstance(resolved, str) + xr_da = self.to_xarray().interpolate_na(dim=resolved, **kwargs) self.values = xr_da.values return self @@ -1282,15 +1315,15 @@ def concat( # ============= Aggregation methods =========== - def max(self, axis: int | str = 0, **kwargs: Any) -> DataArray: - """Max value along an axis. + def max(self, dim: int | str = 0, *, axis: int | str | None = None) -> DataArray: + """Max value along a dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 - **kwargs: Any - Additional keyword arguments + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1302,17 +1335,17 @@ def max(self, axis: int | str = 0, **kwargs: Any) -> DataArray: nanmax : Max values with NaN values removed """ - return self.aggregate(axis=axis, func=np.max, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.max) - def min(self, axis: int | str = 0, **kwargs: Any) -> DataArray: - """Min value along an axis. + def min(self, dim: int | str = 0, *, axis: int | str | None = None) -> DataArray: + """Min value along a dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 - **kwargs: Any - Additional keyword arguments + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1324,17 +1357,19 @@ def min(self, axis: int | str = 0, **kwargs: Any) -> DataArray: nanmin : Min values with NaN values removed """ - return self.aggregate(axis=axis, func=np.min, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.min) - def mean(self, axis: int | str | None = 0, **kwargs: Any) -> DataArray: - """Mean value along an axis. + def mean( + self, dim: int | str | None = 0, *, axis: int | str | None = None + ) -> DataArray: + """Mean value along a dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 - **kwargs: Any - Additional keyword arguments + dim: int, str or None, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1346,17 +1381,17 @@ def mean(self, axis: int | str | None = 0, **kwargs: Any) -> DataArray: nanmean : Mean values with NaN values removed """ - return self.aggregate(axis=axis, func=np.mean, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.mean) - def std(self, axis: int | str = 0, **kwargs: Any) -> DataArray: - """Standard deviation values along an axis. + def std(self, dim: int | str = 0, *, axis: int | str | None = None) -> DataArray: + """Standard deviation along a dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 - **kwargs: Any - Additional keyword arguments + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1368,17 +1403,17 @@ def std(self, axis: int | str = 0, **kwargs: Any) -> DataArray: nanstd : Standard deviation values with NaN values removed """ - return self.aggregate(axis=axis, func=np.std, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.std) - def ptp(self, axis: int | str = 0, **kwargs: Any) -> DataArray: - """Range (max - min) a.k.a Peak to Peak along an axis. + def ptp(self, dim: int | str = 0, *, axis: int | str | None = None) -> DataArray: + """Range (max - min) a.k.a Peak to Peak along a dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 - **kwargs: Any - Additional keyword arguments + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1386,21 +1421,25 @@ def ptp(self, axis: int | str = 0, **kwargs: Any) -> DataArray: array with peak to peak values """ - return self.aggregate(axis=axis, func=np.ptp, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.ptp) def average( - self, weights: np.ndarray, axis: int | str = 0, **kwargs: Any + self, + weights: np.ndarray, + dim: int | str = 0, + *, + axis: int | str | None = None, ) -> DataArray: - """Compute the weighted average along the specified axis. + """Compute the weighted average along the specified dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default weights: np.ndarray weights to apply to the values - **kwargs: Any - Additional keyword arguments + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1417,7 +1456,7 @@ def average( import mikeio da= mikeio.read("../data/HD2D.dfsu")["Current speed"] area = da.geometry.get_element_area() - da.average(axis="space", weights=area) + da.average("space", weights=area) ``` """ @@ -1428,17 +1467,17 @@ def func(x, axis, keepdims): # type: ignore return np.average(x, weights=weights, axis=axis) - return self.aggregate(axis=axis, func=func, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=func) - def nanmax(self, axis: int | str = 0, **kwargs: Any) -> DataArray: - """Max value along an axis (NaN removed). + def nanmax(self, dim: int | str = 0, *, axis: int | str | None = None) -> DataArray: + """Max value along a dimension (NaN removed). Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 - **kwargs: Any - Additional keyword arguments + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1450,17 +1489,17 @@ def nanmax(self, axis: int | str = 0, **kwargs: Any) -> DataArray: nanmax : Max values with NaN values removed """ - return self.aggregate(axis=axis, func=np.nanmax, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.nanmax) - def nanmin(self, axis: int | str = 0, **kwargs: Any) -> DataArray: - """Min value along an axis (NaN removed). + def nanmin(self, dim: int | str = 0, *, axis: int | str | None = None) -> DataArray: + """Min value along a dimension (NaN removed). Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 - **kwargs: Any - Additional keyword arguments + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1472,17 +1511,19 @@ def nanmin(self, axis: int | str = 0, **kwargs: Any) -> DataArray: nanmin : Min values with NaN values removed """ - return self.aggregate(axis=axis, func=np.nanmin, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.nanmin) - def nanmean(self, axis: int | str | None = 0, **kwargs: Any) -> DataArray: - """Mean value along an axis (NaN removed). + def nanmean( + self, dim: int | str | None = 0, *, axis: int | str | None = None + ) -> DataArray: + """Mean value along a dimension (NaN removed). Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 - **kwargs: Any - Additional keyword arguments + dim: int, str or None, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1494,17 +1535,17 @@ def nanmean(self, axis: int | str | None = 0, **kwargs: Any) -> DataArray: mean : Mean values """ - return self.aggregate(axis=axis, func=np.nanmean, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.nanmean) - def nanstd(self, axis: int | str = 0, **kwargs: Any) -> DataArray: - """Standard deviation value along an axis (NaN removed). + def nanstd(self, dim: int | str = 0, *, axis: int | str | None = None) -> DataArray: + """Standard deviation along a dimension (NaN removed). Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 - **kwargs: Any - Additional keyword arguments + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1516,24 +1557,25 @@ def nanstd(self, axis: int | str = 0, **kwargs: Any) -> DataArray: std : Standard deviation """ - return self.aggregate(axis=axis, func=np.nanstd, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.nanstd) def aggregate( self, - axis: int | str | None = 0, + dim: int | str | None = 0, func: Callable[..., Any] = np.nanmean, - **kwargs: Any, + *, + axis: int | str | None = None, ) -> DataArray: """Aggregate along an axis. Parameters ---------- - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 + dim: int or str, optional + dimension to aggregate over, 0 (time) by default func: function, optional default np.nanmean - **kwargs: Any - Additional keyword arguments + axis: int or str, optional + deprecated, use dim instead Returns ------- @@ -1546,7 +1588,8 @@ def aggregate( nanmax : Max values with NaN values removed """ - parsed_axis = self._parse_axis(axis) + dim = _resolve_deprecated_axis(dim, axis) + parsed_axis = self._parse_axis(dim) time = self._time_by_agg_axis(self.time, parsed_axis) if isinstance(parsed_axis, int): @@ -1555,14 +1598,12 @@ def aggregate( axes = parsed_axis # type: ignore item = deepcopy(self.item) - if "name" in kwargs: - item.name = kwargs.pop("name") with ( warnings.catch_warnings() ): # there might be all-Nan slices, it is ok, so we ignore them! warnings.simplefilter("ignore", category=RuntimeWarning) - data = func(self.to_numpy(), axis=parsed_axis, keepdims=False, **kwargs) + data = func(self.to_numpy(), axis=parsed_axis, keepdims=False) if parsed_axis == 0 and "time" in self.dims and len(axes) == 1: # Time-only aggregation - preserve geometry @@ -1597,9 +1638,14 @@ def quantile(self, q: float, **kwargs: Any) -> DataArray: ... def quantile(self, q: Sequence[float], **kwargs: Any) -> Dataset: ... def quantile( - self, q: float | Sequence[float], *, axis: int | str = 0, **kwargs: Any + self, + q: float | Sequence[float], + *, + dim: int | str = 0, + axis: int | str | None = None, + **kwargs: Any, ) -> DataArray | Dataset: - """Compute the q-th quantile of the data along the specified axis. + """Compute the q-th quantile of the data along the specified dimension. Wrapping np.quantile @@ -1608,8 +1654,10 @@ def quantile( q: array_like of float Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive. - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim **kwargs: Any Additional keyword arguments @@ -1622,14 +1670,15 @@ def quantile( -------- >>> da.quantile(q=[0.25,0.75]) >>> da.quantile(q=0.5) - >>> da.quantile(q=[0.01,0.5,0.99], axis="space") + >>> da.quantile(q=[0.01,0.5,0.99], dim="space") See Also -------- nanquantile : quantile with NaN values ignored """ - return self._quantile(q, axis=axis, func=np.quantile, **kwargs) + resolved = _resolve_deprecated_axis(dim, axis) + return self._quantile(q, axis=resolved, func=np.quantile, **kwargs) @overload def nanquantile(self, q: float, **kwargs: Any) -> DataArray: ... @@ -1638,9 +1687,14 @@ def nanquantile(self, q: float, **kwargs: Any) -> DataArray: ... def nanquantile(self, q: Sequence[float], **kwargs: Any) -> Dataset: ... def nanquantile( - self, q: float | Sequence[float], *, axis: int | str = 0, **kwargs: Any + self, + q: float | Sequence[float], + *, + dim: int | str = 0, + axis: int | str | None = None, + **kwargs: Any, ) -> DataArray | Dataset: - """Compute the q-th quantile of the data along the specified axis, while ignoring nan values. + """Compute the q-th quantile of the data along the specified dimension, ignoring NaN values. Wrapping np.nanquantile @@ -1649,8 +1703,10 @@ def nanquantile( q: array_like of float Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive. - axis: (int, str, None), optional - axis number or "time" or "space", by default 0 + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim **kwargs: Any Additional keyword arguments @@ -1663,14 +1719,15 @@ def nanquantile( -------- >>> da.nanquantile(q=[0.25,0.75]) >>> da.nanquantile(q=0.5) - >>> da.nanquantile(q=[0.01,0.5,0.99], axis="space") + >>> da.nanquantile(q=[0.01,0.5,0.99], dim="space") See Also -------- quantile : Quantile with NaN values """ - return self._quantile(q, axis=axis, func=np.nanquantile, **kwargs) + resolved = _resolve_deprecated_axis(dim, axis) + return self._quantile(q, axis=resolved, func=np.nanquantile, **kwargs) def _quantile(self, q, *, axis: int | str = 0, func=np.quantile, **kwargs: Any): # type: ignore from mikeio import Dataset diff --git a/src/mikeio/dataset/_dataset.py b/src/mikeio/dataset/_dataset.py index 10289c003..8c8ea332c 100644 --- a/src/mikeio/dataset/_dataset.py +++ b/src/mikeio/dataset/_dataset.py @@ -26,7 +26,7 @@ import xarray import polars as pl -from ._dataarray import DataArray +from ._dataarray import DataArray, _resolve_deprecated_axis from .._track import _extract_track from ..eum import EUMType, EUMUnit, ItemInfo from ..spatial import ( @@ -969,10 +969,17 @@ def interp_time( return Dataset(das) - def interp_na(self, axis: str = "time", **kwargs: Any) -> Dataset: + def interp_na( + self, + dim: str = "time", + *, + axis: str | None = None, + **kwargs: Any, + ) -> Dataset: + resolved = _resolve_deprecated_axis(dim, axis) ds = self.copy() for da in ds: - da.values = da.interp_na(axis=axis, **kwargs).values + da.values = da.interp_na(dim=resolved, **kwargs).values return ds @@ -1188,18 +1195,22 @@ def _check_datasets_match(self, other: Dataset) -> None: # ============ aggregate ============= def aggregate( - self, axis: int | str | None = 0, func: Callable = np.nanmean, **kwargs: Any + self, + dim: int | str | None = 0, + func: Callable = np.nanmean, + *, + axis: int | str | None = None, ) -> Dataset: - """Aggregate along an axis. + """Aggregate along a dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 + dim: int or str, optional + dimension to aggregate over, 0 (time) by default func: function, optional default np.nanmean - **kwargs: Any - additional arguments passed to the function + axis: int or str, optional + deprecated, use dim instead Returns ------- @@ -1207,13 +1218,13 @@ def aggregate( dataset with aggregated values """ - if axis == "items": + dim = _resolve_deprecated_axis(dim, axis) + if dim == "items": if self.n_items <= 1: return self - name = kwargs.pop("name", func.__name__) - data = func(self.to_numpy(), axis=0, **kwargs) - item = self._agg_item_from_items(self.items, name) + data = func(self.to_numpy(), axis=0) + item = self._agg_item_from_items(self.items, func.__name__) da = DataArray( data=data, time=self.time, @@ -1225,7 +1236,7 @@ def aggregate( return Dataset([da], validate=False) else: res = { - name: da.aggregate(axis=axis, func=func, **kwargs) + name: da.aggregate(dim=dim, func=func) for name, da in self._data_vars.items() } return Dataset(data=res, validate=False) @@ -1245,9 +1256,14 @@ def _agg_item_from_items(items: Sequence[ItemInfo], name: str) -> ItemInfo: return ItemInfo(name, it_type, it_unit) def quantile( - self, q: float | Sequence[float], *, axis: int | str = 0, **kwargs: Any + self, + q: float | Sequence[float], + *, + dim: int | str = 0, + axis: int | str | None = None, + **kwargs: Any, ) -> Dataset: - """Compute the q-th quantile of the data along the specified axis. + """Compute the q-th quantile of the data along the specified dimension. Wrapping np.quantile @@ -1256,8 +1272,10 @@ def quantile( q: array_like of float Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive. - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim **kwargs: Any additional arguments passed to the function @@ -1270,19 +1288,25 @@ def quantile( -------- >>> ds.quantile(q=[0.25,0.75]) >>> ds.quantile(q=0.5) - >>> ds.quantile(q=[0.01,0.5,0.99], axis="space") + >>> ds.quantile(q=[0.01,0.5,0.99], dim="space") See Also -------- nanquantile : quantile with NaN values ignored """ - return self._quantile(q, axis=axis, func=np.quantile, **kwargs) + resolved = _resolve_deprecated_axis(dim, axis) + return self._quantile(q, axis=resolved, func=np.quantile, **kwargs) def nanquantile( - self, q: float | Sequence[float], *, axis: int | str = 0, **kwargs: Any + self, + q: float | Sequence[float], + *, + dim: int | str = 0, + axis: int | str | None = None, + **kwargs: Any, ) -> Dataset: - """Compute the q-th quantile of the data along the specified axis, while ignoring nan values. + """Compute the q-th quantile of the data along the specified dimension, ignoring NaN values. Wrapping np.nanquantile @@ -1291,8 +1315,10 @@ def nanquantile( q: array_like of float Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive. - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim **kwargs: Any additional arguments passed to the function @@ -1300,7 +1326,7 @@ def nanquantile( -------- >>> ds.nanquantile(q=[0.25,0.75]) >>> ds.nanquantile(q=0.5) - >>> ds.nanquantile(q=[0.01,0.5,0.99], axis="space") + >>> ds.nanquantile(q=[0.01,0.5,0.99], dim="space") Returns ------- @@ -1308,7 +1334,8 @@ def nanquantile( dataset with quantile values """ - return self._quantile(q, axis=axis, func=np.nanquantile, **kwargs) + resolved = _resolve_deprecated_axis(dim, axis) + return self._quantile(q, axis=resolved, func=np.nanquantile, **kwargs) def _quantile(self, q, *, axis=0, func=np.quantile, **kwargs) -> Dataset: # type: ignore if axis == "items": @@ -1348,15 +1375,15 @@ def _quantile(self, q, *, axis=0, func=np.quantile, **kwargs) -> Dataset: # typ return Dataset(data=res, validate=False) - def max(self, axis: int | str = 0, **kwargs: Any) -> Dataset: - """Max value along an axis. + def max(self, dim: int | str = 0, *, axis: int | str | None = None) -> Dataset: + """Max value along a dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 - **kwargs: Any - additional arguments passed to the function + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1368,17 +1395,17 @@ def max(self, axis: int | str = 0, **kwargs: Any) -> Dataset: nanmax : Max values with NaN values removed """ - return self.aggregate(axis=axis, func=np.max, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.max) - def min(self, axis: int | str = 0, **kwargs: Any) -> Dataset: - """Min value along an axis. + def min(self, dim: int | str = 0, *, axis: int | str | None = None) -> Dataset: + """Min value along a dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 - **kwargs: Any - additional arguments passed to the function + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1390,17 +1417,17 @@ def min(self, axis: int | str = 0, **kwargs: Any) -> Dataset: nanmin : Min values with NaN values removed """ - return self.aggregate(axis=axis, func=np.min, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.min) - def mean(self, axis: int | str = 0, **kwargs: Any) -> Dataset: - """Mean value along an axis. + def mean(self, dim: int | str = 0, *, axis: int | str | None = None) -> Dataset: + """Mean value along a dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 - **kwargs: Any - additional arguments passed to the function + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1413,17 +1440,17 @@ def mean(self, axis: int | str = 0, **kwargs: Any) -> Dataset: average : Weighted average """ - return self.aggregate(axis=axis, func=np.mean, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.mean) - def std(self, axis: int | str = 0, **kwargs: Any) -> Dataset: - """Standard deviation along an axis. + def std(self, dim: int | str = 0, *, axis: int | str | None = None) -> Dataset: + """Standard deviation along a dimension. Parameters ---------- - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 - **kwargs: Any - additional arguments passed to the function + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1435,14 +1462,17 @@ def std(self, axis: int | str = 0, **kwargs: Any) -> Dataset: nanstd : Standard deviation with NaN values removed """ - return self.aggregate(axis=axis, func=np.std, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.std) - def ptp(self, axis: int | str = 0, **kwargs: Any) -> Dataset: - """Range (max - min) a.k.a Peak to Peak along an axis - Parameters. + def ptp(self, dim: int | str = 0, *, axis: int | str | None = None) -> Dataset: + """Range (max - min) a.k.a Peak to Peak along a dimension. + + Parameters ---------- - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1450,10 +1480,12 @@ def ptp(self, axis: int | str = 0, **kwargs: Any) -> Dataset: dataset with peak to peak values """ - return self.aggregate(axis=axis, func=np.ptp, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.ptp) - def average(self, *, weights, axis=0, **kwargs) -> Dataset: # type: ignore - """Compute the weighted average along the specified axis. + def average( + self, *, weights: Any, dim: int | str = 0, axis: int | str | None = None + ) -> Dataset: + """Compute the weighted average along the specified dimension. Wraps [](`numpy.average`) @@ -1461,10 +1493,10 @@ def average(self, *, weights, axis=0, **kwargs) -> Dataset: # type: ignore ---------- weights: array_like weights to average over - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 - **kwargs: Any - additional arguments passed to the function + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1481,7 +1513,7 @@ def average(self, *, weights, axis=0, **kwargs) -> Dataset: # type: ignore >>> dfs = Dfsu("HD2D.dfsu") >>> ds = dfs.read(["Current speed"]) >>> area = dfs.get_element_area() - >>> ds2 = ds.average(axis="space", weights=area) + >>> ds2 = ds.average(dim="space", weights=area) """ @@ -1491,21 +1523,23 @@ def func(x, axis, keepdims): # type: ignore return np.average(x, weights=weights, axis=axis) - return self.aggregate(axis=axis, func=func, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=func) - def nanmax(self, axis: int | str | None = 0, **kwargs: Any) -> Dataset: - """Max value along an axis (NaN removed). + def nanmax( + self, dim: int | str | None = 0, *, axis: int | str | None = None + ) -> Dataset: + """Max value along a dimension (NaN removed). Parameters ---------- - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 - **kwargs: Any - additional arguments passed to the function + dim: int, str or None, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim See Also -------- - max : Mean values + max : Max values Returns ------- @@ -1513,17 +1547,19 @@ def nanmax(self, axis: int | str | None = 0, **kwargs: Any) -> Dataset: dataset with max values """ - return self.aggregate(axis=axis, func=np.nanmax, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.nanmax) - def nanmin(self, axis: int | str | None = 0, **kwargs: Any) -> Dataset: - """Min value along an axis (NaN removed). + def nanmin( + self, dim: int | str | None = 0, *, axis: int | str | None = None + ) -> Dataset: + """Min value along a dimension (NaN removed). Parameters ---------- - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 - **kwargs: Any - additional arguments passed to the function + dim: int, str or None, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1531,17 +1567,17 @@ def nanmin(self, axis: int | str | None = 0, **kwargs: Any) -> Dataset: dataset with min values """ - return self.aggregate(axis=axis, func=np.nanmin, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.nanmin) - def nanmean(self, axis: int | str = 0, **kwargs: Any) -> Dataset: - """Mean value along an axis (NaN removed). + def nanmean(self, dim: int | str = 0, *, axis: int | str | None = None) -> Dataset: + """Mean value along a dimension (NaN removed). Parameters ---------- - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 - **kwargs: Any - additional arguments passed to the function + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1549,17 +1585,17 @@ def nanmean(self, axis: int | str = 0, **kwargs: Any) -> Dataset: dataset with mean values """ - return self.aggregate(axis=axis, func=np.nanmean, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.nanmean) - def nanstd(self, axis: int | str = 0, **kwargs: Any) -> Dataset: - """Standard deviation along an axis (NaN removed). + def nanstd(self, dim: int | str = 0, *, axis: int | str | None = None) -> Dataset: + """Standard deviation along a dimension (NaN removed). Parameters ---------- - axis: (int, str, None), optional - axis number or "time", "space" or "items", by default 0 - **kwargs: Any - additional arguments passed to the function + dim: int or str, optional + dimension, by default 0 (time) + axis: int or str, optional + deprecated, use dim Returns ------- @@ -1571,7 +1607,7 @@ def nanstd(self, axis: int | str = 0, **kwargs: Any) -> Dataset: std : Standard deviation """ - return self.aggregate(axis=axis, func=np.nanstd, **kwargs) + return self.aggregate(dim=dim, axis=axis, func=np.nanstd) # ============ arithmetic/Math ============= diff --git a/tests/test_dataarray.py b/tests/test_dataarray.py index 7c91f9f9c..2983af2c8 100644 --- a/tests/test_dataarray.py +++ b/tests/test_dataarray.py @@ -1,5 +1,6 @@ from datetime import datetime from pathlib import Path +import warnings import numpy as np import pandas as pd import matplotlib.pyplot as plt @@ -957,10 +958,10 @@ def test_daarray_aggregation_dfs2() -> None: assert da.shape == (1, 264, 216) - dam = da.nanmean(axis=None) + dam = da.nanmean(None) assert np.isscalar(dam.values) # TODO is this what we want - dasm = da.nanmean(axis="space") + dasm = da.nanmean("space") assert dasm.shape == (1,) @@ -972,7 +973,7 @@ def test_dataarray_weigthed_average() -> None: area = da.geometry.get_element_area() - da2 = da.average(weights=area, axis=1) + da2 = da.average(weights=area, dim=1) assert isinstance(da2.geometry, mikeio.spatial.Geometry0D) assert da2.dims == ("time",) @@ -1010,7 +1011,8 @@ def test_daarray_aggregation() -> None: assert pytest.approx(da_mean.values[0]) == 0.04334851 assert pytest.approx(da_mean.values[778]) == 0.452692 - da_std = da.std(name="standard deviation") + da_std = da.std() + da_std.name = "standard deviation" assert isinstance(da_std, mikeio.DataArray) assert da_std.name == "standard deviation" assert da_std.geometry == da.geometry @@ -1018,7 +1020,8 @@ def test_daarray_aggregation() -> None: assert len(da_std.time) == 1 assert pytest.approx(da_std.values[0]) == 0.015291579 - da_ptp = da.ptp(name="peak to peak (max - min)") + da_ptp = da.ptp() + da_ptp.name = "peak to peak (max - min)" assert isinstance(da_std, mikeio.DataArray) assert da_ptp.geometry == da.geometry assert da_ptp.start_time == da.start_time @@ -1076,7 +1079,7 @@ def test_daarray_aggregation_nan_versions() -> None: def test_da_quantile_axis0(da2: DataArray) -> None: assert da2.geometry.nx == 7 assert len(da2.time) == 10 - daq = da2.quantile(q=0.345, axis="time") + daq = da2.quantile(q=0.345, dim="time") assert daq.geometry.nx == 7 assert len(da2.time) == 10 # this should not change assert len(daq.time) == 1 # aggregated @@ -1086,7 +1089,7 @@ def test_da_quantile_axis0(da2: DataArray) -> None: assert daq.dims[0] == "x" assert daq.n_timesteps == 1 - daqs = da2.quantile(q=0.345, axis="space") + daqs = da2.quantile(q=0.345, dim="space") assert isinstance( daqs.geometry, mikeio.spatial.Geometry0D ) # Aggregating over space returns Geometry0D (time series) @@ -1096,7 +1099,7 @@ def test_da_quantile_axis0(da2: DataArray) -> None: assert daqs.dims[0][0] == "t" # Because it's a mikeio.Grid1D, remember! # q as list - daq = da2.quantile(q=[0.25, 0.75], axis=0) + daq = da2.quantile(q=[0.25, 0.75], dim=0) assert isinstance(daq, mikeio.Dataset) assert daq.n_items == 2 assert daq[0].to_numpy()[0] == 0.1 @@ -1409,7 +1412,7 @@ def test_parse_time_decreasing() -> None: def test_geometry0d_space_axis_raises() -> None: - """Test that Geometry0D raises ValueError for axis='space'.""" + """Test that Geometry0D raises ValueError for dim='space'.""" from mikeio.spatial import Geometry0D da = mikeio.DataArray( @@ -1418,17 +1421,17 @@ def test_geometry0d_space_axis_raises() -> None: geometry=Geometry0D(), ) with pytest.raises(ValueError, match="space axis cannot be selected"): - da.mean(axis="space") + da.mean("space") def test_point_spectrum_space_axis_raises() -> None: - """Test that point spectrum raises ValueError for axis='space'.""" + """Test that point spectrum raises ValueError for dim='space'.""" dfs = mikeio.open("tests/testdata/spectra/pt_spectra.dfsu") assert isinstance(dfs, DfsuSpectral) ds = dfs.read() da = ds[0] with pytest.raises(ValueError, match="space axis cannot be selected"): - da.mean(axis="space") + da.mean("space") def test_area_spectrum_space_axis() -> None: @@ -1438,7 +1441,7 @@ def test_area_spectrum_space_axis() -> None: ds = dfs.read() da = ds[0] # Should aggregate over element axis only - result = da.mean(axis="space") + result = da.mean("space") # Result should have time, direction, frequency but not element assert "element" not in result.dims assert result.shape == (3, 16, 25) # time, direction, frequency @@ -1451,18 +1454,18 @@ def test_line_spectrum_space_axis() -> None: ds = dfs.read() da = ds[0] # Should aggregate over node axis only - result = da.mean(axis="space") + result = da.mean("space") # Result should have time, direction, frequency but not node assert "node" not in result.dims assert result.shape == (4, 16, 25) # time, direction, frequency def test_parse_axis_none_default() -> None: - """Test that axis=None defaults to all axes.""" + """Test that dim=None defaults to all axes.""" ds = mikeio.read("tests/testdata/waves.dfs2") da = ds[0] - # axis=None should aggregate over all axes - result = da.mean(axis=None) + # dim=None should aggregate over all axes + result = da.mean(None) assert result.ndim == 0 # Scalar result @@ -1474,16 +1477,84 @@ def test_grid2d_space_axis_with_time() -> None: # Grid2D has dims ("y", "x") so space axis should be (0, 1) assert da.geometry.get_space_axis() == (0, 1) # With time, should aggregate over axes (1, 2) - y and x - result = da.mean(axis="space") + result = da.mean("space") assert result.shape == (3,) # Only time dimension left assert result.dims == ("time",) def test_axis_spatial_deprecated() -> None: - """Test that axis='spatial' emits FutureWarning and works like 'space'.""" + """Test that dim='spatial' emits FutureWarning and works like 'space'.""" ds = mikeio.read("tests/testdata/waves.dfs2") da = ds[0] with pytest.warns(FutureWarning, match="axis='spatial' is deprecated"): - result = da.mean(axis="spatial") + result = da.mean("spatial") assert result.shape == (3,) assert result.dims == ("time",) + + +def test_dim_keyword_mean(da2: mikeio.DataArray) -> None: + result = da2.mean(dim="time") + assert result.n_timesteps == 1 + assert result.geometry.nx == 7 + + result_space = da2.mean(dim="space") + assert len(result_space.time) == 10 + + +def test_dim_keyword_max_min(da2: mikeio.DataArray) -> None: + assert da2.max(dim="time").n_timesteps == 1 + assert da2.min(dim="space").ndim == 1 + + +def test_dim_keyword_std(da2: mikeio.DataArray) -> None: + result = da2.std(dim="time") + assert result.n_timesteps == 1 + + +def test_dim_keyword_nanmean(da2: mikeio.DataArray) -> None: + result = da2.nanmean(dim="time") + assert result.n_timesteps == 1 + + +def test_dim_keyword_nanmax_nanmin(da2: mikeio.DataArray) -> None: + assert da2.nanmax(dim="time").n_timesteps == 1 + assert da2.nanmin(dim="space").ndim == 1 + + +def test_dim_keyword_nanstd(da2: mikeio.DataArray) -> None: + result = da2.nanstd(dim="time") + assert result.n_timesteps == 1 + + +def test_dim_keyword_quantile(da2: mikeio.DataArray) -> None: + result = da2.quantile(q=0.5, dim="time") + assert result.n_timesteps == 1 + + result_space = da2.nanquantile(q=0.5, dim="space") + assert len(result_space.time) == 10 + + +def test_dim_keyword_aggregate(da2: mikeio.DataArray) -> None: + result = da2.aggregate(dim="time") + assert result.n_timesteps == 1 + + +def test_dim_keyword_ptp(da2: mikeio.DataArray) -> None: + result = da2.ptp(dim="time") + assert result.n_timesteps == 1 + + +def test_axis_keyword_deprecation_warning(da2: mikeio.DataArray) -> None: + with pytest.warns(FutureWarning, match="'axis' keyword is deprecated"): + da2.mean(axis="time") + + +def test_axis_keyword_int_deprecation_warning(da2: mikeio.DataArray) -> None: + with pytest.warns(FutureWarning, match="'axis' keyword is deprecated"): + da2.mean(axis=0) + + +def test_positional_string_no_warning(da2: mikeio.DataArray) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("error", FutureWarning) + da2.mean("space") # positional string should not warn diff --git a/tests/test_dataset.py b/tests/test_dataset.py index f9133a2cb..1e7bebece 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,5 +1,6 @@ from pathlib import Path from datetime import datetime +import warnings import numpy as np import pandas as pd import pytest @@ -598,20 +599,20 @@ def test_aggregation_dataset_no_time() -> None: def test_aggregations() -> None: ds = mikeio.read("tests/testdata/gebco_sound.dfs2") - for axis in [0, 1, 2]: - ds.mean(axis=axis) - ds.nanmean(axis=axis) - ds.nanmin(axis=axis) - ds.nanmax(axis=axis) + for dim in [0, 1, 2]: + ds.mean(dim) + ds.nanmean(dim) + ds.nanmin(dim) + ds.nanmax(dim) assert ds.mean().shape == (264, 216) assert ds.mean("time").shape == (264, 216) assert ds.mean("space").shape == (1,) with pytest.raises(ValueError, match="space"): - ds.mean(axis="spaghetti") + ds.mean("spaghetti") - dsm = ds.mean(axis="time") + dsm = ds.mean("time") assert dsm.geometry is not None @@ -628,7 +629,7 @@ def test_to_dfs_extension_validation(tmp_path: Path) -> None: def test_quantile_axis1(ds1: Dataset) -> None: - dsq = ds1.quantile(q=0.345, axis=1) + dsq = ds1.quantile(q=0.345, dim=1) assert dsq[0].to_numpy()[0] == 0.1 assert dsq[1].to_numpy()[0] == 0.2 @@ -636,7 +637,7 @@ def test_quantile_axis1(ds1: Dataset) -> None: assert dsq.n_timesteps == ds1.n_timesteps # q as list - dsq = ds1.quantile(q=[0.25, 0.75], axis=1) + dsq = ds1.quantile(q=[0.25, 0.75], dim=1) assert dsq.n_items == 2 * ds1.n_items assert "Quantile 0.75, " in dsq.items[1].name assert "Quantile 0.25, " in dsq.items[2].name @@ -652,7 +653,7 @@ def test_quantile_axis0(ds1: Dataset) -> None: assert dsq.shape[-1] == ds1.shape[-1] # q as list - dsq = ds1.quantile(q=[0.25, 0.75], axis=0) + dsq = ds1.quantile(q=[0.25, 0.75], dim=0) assert dsq.n_items == 2 * ds1.n_items assert dsq[0].to_numpy()[0] == 0.1 assert dsq[1].to_numpy()[0] == 0.1 @@ -685,13 +686,13 @@ def test_nanquantile() -> None: def test_aggregate_across_items() -> None: ds = mikeio.read("tests/testdata/State_wlbc_north_err.dfs1") - dsm = ds.mean(axis="items") + dsm = ds.mean("items") assert isinstance(dsm, mikeio.Dataset) assert dsm.geometry == ds.geometry assert dsm.dims == ds.dims - dsq = ds.quantile(q=[0.1, 0.5, 0.9], axis="items") + dsq = ds.quantile(q=[0.1, 0.5, 0.9], dim="items") assert isinstance(dsq, mikeio.Dataset) assert dsq[0].name == "Quantile 0.1" assert dsq[1].name == "Quantile 0.5" @@ -705,8 +706,9 @@ def test_aggregate_selected_items_dfsu_save_to_new_file(tmp_path: Path) -> None: assert ds.n_items == 5 - dsm = ds.max(axis="items", name="Max Water Level") # add a nice name + dsm = ds.max(dim="items") assert len(dsm) == 1 + dsm[0].name = "Max Water Level" assert dsm[0].name == "Max Water Level" assert dsm.geometry == ds.geometry assert dsm.dims == ds.dims @@ -772,7 +774,7 @@ def test_dfsu3d_dataset() -> None: assert len(ds) == 2 # Salinity, Temperature - dsagg = ds.nanmean(axis=0) # Time averaged + dsagg = ds.nanmean(0) # Time averaged assert len(dsagg) == 2 # Salinity, Temperature @@ -1427,3 +1429,76 @@ def test_safe_name() -> None: bad_name = "MSLP., 1:st level\n 2nd chain" safe_name = "MSLP_1_st_level_2nd_chain" assert _to_safe_name(bad_name) == safe_name + + +def test_dim_keyword_mean(ds1: Dataset) -> None: + result = ds1.mean(dim="time") + assert result.n_timesteps == 1 + assert result.n_items == 2 + + result_space = ds1.mean(dim="space") + assert len(result_space.time) == 10 + + +def test_dim_keyword_max_min(ds1: Dataset) -> None: + assert ds1.max(dim="time").n_timesteps == 1 + assert ds1.min(dim="space")[0].ndim == 1 + + +def test_dim_keyword_std(ds1: Dataset) -> None: + result = ds1.std(dim="time") + assert result.n_timesteps == 1 + + +def test_dim_keyword_nanmean(ds1: Dataset) -> None: + result = ds1.nanmean(dim="time") + assert result.n_timesteps == 1 + + +def test_dim_keyword_nanmax_nanmin(ds1: Dataset) -> None: + assert ds1.nanmax(dim="time").n_timesteps == 1 + assert ds1.nanmin(dim="space")[0].ndim == 1 + + +def test_dim_keyword_nanstd(ds1: Dataset) -> None: + result = ds1.nanstd(dim="time") + assert result.n_timesteps == 1 + + +def test_dim_keyword_quantile(ds1: Dataset) -> None: + result = ds1.quantile(q=0.5, dim="time") + assert result.n_timesteps == 1 + + result_space = ds1.nanquantile(q=0.5, dim="space") + assert len(result_space.time) == 10 + + +def test_dim_keyword_aggregate(ds1: Dataset) -> None: + result = ds1.aggregate(dim="time") + assert result.n_timesteps == 1 + + +def test_dim_keyword_items(ds1: Dataset) -> None: + result = ds1.mean(dim="items") + assert result.n_items == 1 + + +def test_dim_keyword_ptp(ds1: Dataset) -> None: + result = ds1.ptp(dim="time") + assert result.n_timesteps == 1 + + +def test_axis_keyword_deprecation_warning_dataset(ds1: Dataset) -> None: + with pytest.warns(FutureWarning, match="'axis' keyword is deprecated"): + ds1.mean(axis="time") + + +def test_axis_keyword_int_deprecation_warning_dataset(ds1: Dataset) -> None: + with pytest.warns(FutureWarning, match="'axis' keyword is deprecated"): + ds1.mean(axis=0) + + +def test_positional_string_no_warning_dataset(ds1: Dataset) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("error", FutureWarning) + ds1.mean("space") # positional string should not warn diff --git a/tests/test_dfs2.py b/tests/test_dfs2.py index bec7cda59..d3567c93a 100644 --- a/tests/test_dfs2.py +++ b/tests/test_dfs2.py @@ -677,7 +677,7 @@ def test_spatial_aggregation_dfs2_to_dfs0(tmp_path: Path) -> None: outfilename = tmp_path / "waves_max.dfs0" ds = mikeio.read("tests/testdata/waves.dfs2") - ds_max = ds.nanmax(axis="space") + ds_max = ds.nanmax("space") ds_max.to_dfs(outfilename) dsnew = mikeio.read(outfilename) diff --git a/tests/test_dfs3.py b/tests/test_dfs3.py index 960459ac5..b63685b77 100644 --- a/tests/test_dfs3.py +++ b/tests/test_dfs3.py @@ -124,8 +124,8 @@ def test_read_top_layer() -> None: dssel = dsall.isel(z=-1) assert dssel.geometry == ds.geometry dsdiff = dssel - ds - assert dsdiff.nanmax(axis=None).to_numpy()[0] == 0.0 - assert dsdiff.nanmin(axis=None).to_numpy()[0] == 0.0 + assert dsdiff.nanmax(None).to_numpy()[0] == 0.0 + assert dsdiff.nanmin(None).to_numpy()[0] == 0.0 def test_read_bottom_layer() -> None: diff --git a/tests/test_generic.py b/tests/test_generic.py index b149a2ca8..008464848 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -543,7 +543,7 @@ def test_time_average(tmp_path: Path) -> None: assert org.time[0] == averaged.time[0] assert org.shape[1] == averaged.shape[1] assert averaged.shape[0] == 1 - assert np.allclose(org.mean(axis=0)[0].to_numpy(), averaged[0].to_numpy()) + assert np.allclose(org.mean(0)[0].to_numpy(), averaged[0].to_numpy()) def test_time_average_dfsu_3d(tmp_path: Path) -> None: @@ -581,7 +581,7 @@ def test_quantile_dfsu(tmp_path: Path) -> None: fp = tmp_path / "oresund_q10.dfsu" generic.quantile(infilename, fp, q=0.1, items=["Surface elevation"]) - org = mikeio.read(infilename).quantile(q=0.1, axis=0) + org = mikeio.read(infilename).quantile(q=0.1, dim=0) q10 = mikeio.read(fp) assert np.allclose(org[0].to_numpy(), q10[0].to_numpy()) @@ -592,7 +592,7 @@ def test_quantile_dfsu_buffer_size(tmp_path: Path) -> None: fp = tmp_path / "oresund_q10.dfsu" generic.quantile(infilename, fp, q=0.1, buffer_size=1e5, items=0) - org = mikeio.read(infilename).quantile(q=0.1, axis=0) + org = mikeio.read(infilename).quantile(q=0.1, dim=0) q10 = mikeio.read(fp) assert np.allclose(org[0].to_numpy(), q10[0].to_numpy()) @@ -603,7 +603,7 @@ def test_quantile_dfs2(tmp_path: Path) -> None: fp = tmp_path / "eq_q90.dfs2" generic.quantile(infilename, fp, q=0.9) - org = mikeio.read(infilename).quantile(q=0.9, axis=0) + org = mikeio.read(infilename).quantile(q=0.9, dim=0) q90 = mikeio.read(fp) assert np.allclose(org[0].to_numpy(), q90[0].to_numpy()) @@ -614,7 +614,7 @@ def test_quantile_dfs0(tmp_path: Path) -> None: fp = tmp_path / "da_q001_q05.dfs0" generic.quantile(infilename, fp, q=[0.01, 0.5]) - org = mikeio.read(infilename).quantile(q=[0.01, 0.5], axis=0) + org = mikeio.read(infilename).quantile(q=[0.01, 0.5], dim=0) qnt = mikeio.read(fp) assert np.allclose(org[0].to_numpy(), qnt[0].to_numpy()) diff --git a/tests/test_integration.py b/tests/test_integration.py index 0abf0e440..9894eddf7 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -8,7 +8,7 @@ def test_read_file_multiply_two_items_and_save_to_new_file(tmp_path: Path) -> None: ds = mikeio.read("tests/testdata/oresundHD_run1.dfsu") - da = (ds[0] * ds[1]).nanmax(axis="time") + da = (ds[0] * ds[1]).nanmax("time") file_path = tmp_path / "mult.dfsu" @@ -19,13 +19,13 @@ def test_aggregation_workflows(tmp_path: Path) -> None: dfs = mikeio.Dfsu2DH("tests/testdata/HD2D.dfsu") ds = dfs.read(items=["Surface elevation", "Current speed"]) - ds2 = ds.max(axis=1) + ds2 = ds.max(1) outfilename = tmp_path / "max.dfs0" ds2.to_dfs(outfilename) assert outfilename.exists() - ds3 = ds.min(axis=1) + ds3 = ds.min(1) outfilename = tmp_path / "min.dfs0" ds3.to_dfs(outfilename) @@ -38,7 +38,7 @@ def test_weighted_average(tmp_path: Path) -> None: ds = dfs.read(items=["Surface elevation", "Current speed"]) area = dfs.geometry.get_element_area() - ds2 = ds.average(weights=area, axis=1) + ds2 = ds.average(weights=area, dim=1) out_path = tmp_path / "average.dfs0" ds2.to_dfs(out_path) From d734ef507e420a8da9fde5ee34e7df1f01c02c50 Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Mon, 16 Mar 2026 12:40:56 +0100 Subject: [PATCH 2/2] Add typing overloads to _resolve_deprecated_axis for mypy --- src/mikeio/dataset/_dataarray.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/mikeio/dataset/_dataarray.py b/src/mikeio/dataset/_dataarray.py index 0fefa0d1b..0b624232c 100644 --- a/src/mikeio/dataset/_dataarray.py +++ b/src/mikeio/dataset/_dataarray.py @@ -69,6 +69,27 @@ ) +@overload +def _resolve_deprecated_axis( + dim: str, + axis: str | None, +) -> str: ... + + +@overload +def _resolve_deprecated_axis( + dim: int | str, + axis: int | str | None, +) -> int | str: ... + + +@overload +def _resolve_deprecated_axis( + dim: int | str | None, + axis: int | str | None, +) -> int | str | None: ... + + def _resolve_deprecated_axis( dim: int | str | None, axis: int | str | None,