Source code for xscen.spatial

"""Spatial tools."""

import datetime
import itertools
import warnings
from copy import deepcopy
from pathlib import Path
from typing import Optional

import clisops.core.subset
import dask
import numpy as np
import sparse as sp
import xarray as xr
import xclim as xc
from xclim.core.utils import uses_dask

from .config import parse_config

# Names exported by `from xscen.spatial import *`.
__all__ = [
    "creep_fill",
    "creep_weights",
    "subset",
]


@parse_config
def creep_weights(mask: xr.DataArray, n: int = 1, mode: str = "clip") -> xr.DataArray:
    """Build the sparse weight matrix used by the creep fill.

    Every cell where `mask` is True keeps its own value (weight 1 on itself).
    Every cell where `mask` is False receives the average of the True cells
    found within `n` steps along each dimension; if there are none, the cell
    gets a single NaN weight on itself.

    Parameters
    ----------
    mask : DataArray
        A boolean DataArray. False values are the candidates for filling.
        Usually they represent missing values (`mask = da.notnull()`).
        All dimensions of `mask` are creep filled.
    n : int
        The order of neighbouring to use. 1 means only the adjacent grid cells are used.
    mode : {'clip', 'wrap'}
        How out-of-bounds neighbour indices are handled; passed as-is to
        `np.ravel_multi_index`. With 'wrap', cells on the edge of the domain
        look for neighbours on the opposite side.

    Returns
    -------
    DataArray
        Sparse (COO) weights named "creep_fill_weights". Its dimensions are those of
        `mask` suffixed with "_out", followed by those of `mask`.
        The dot product must be taken over the last N dimensions.
    """
    source = mask
    valid = source.values
    shape = valid.shape

    # All relative offsets reachable within `n` steps, one column per offset.
    steps = [np.arange(-n, n + 1) for _ in range(valid.ndim)]
    offsets = np.array(list(itertools.product(*steps))).T

    # COO triplets: first index block, second index block, weight.
    first_idx = []
    second_idx = []
    weights = []

    walker = np.nditer(valid, flags=["f_index", "multi_index"], order="C")
    for cell in walker:
        here = walker.multi_index

        if cell:
            # Valid cell: identity entry.
            first_idx.append(here)
            second_idx.append(here)
            weights.append(1)
            continue

        # Absolute indices of the neighbourhood; duplicates created by the
        # boundary handling are removed through np.unique on the flat index.
        around = np.atleast_2d(here).T + offsets
        flat = np.ravel_multi_index(around, shape, order="C", mode=mode)
        candidates = np.unravel_index(np.unique(flat), shape, order="C")
        usable = valid[candidates]
        count = usable.sum()

        if count > 0:
            # Equal share for every valid neighbour.
            first_idx.extend([here] * count)
            second_idx.extend(np.stack(candidates)[:, usable].T)
            weights.extend([1 / count] * count)
        else:
            # Nothing to creep from: the cell stays undefined.
            first_idx.append(here)
            second_idx.append(here)
            weights.append(np.nan)

    coords_matrix = np.concatenate(
        (np.array(first_idx).T, np.array(second_idx).T), axis=0
    )
    sparse_weights = sp.COO(coords_matrix, weights, (*source.shape, *source.shape))
    out_dims = [f"{d}_out" for d in source.dims]
    return xr.DataArray(
        sparse_weights,
        dims=out_dims + list(source.dims),
        coords=source.coords,
        name="creep_fill_weights",
    )
@parse_config
def creep_fill(da: xr.DataArray, w: xr.DataArray) -> xr.DataArray:
    """Apply pre-computed creep weights to a DataArray.

    Parameters
    ----------
    da : DataArray
        A DataArray sharing the dimensions with the one used to compute the weights.
        It can have other dimensions.
        Dask is supported as long as there are no chunks over the creeped dims.
    w : DataArray
        The result of `creep_weights`.

    Returns
    -------
    xarray.DataArray
        Same shape as `da`, but values filled according to `w`.

    Examples
    --------
    >>> w = creep_weights(da.isel(time=0).notnull(), n=1)
    >>> da_filled = creep_fill(da, w)
    """
    # `w` carries two groups of N dimensions; the second group is named like
    # the creeped dimensions of `da`.
    n_creep = w.ndim // 2
    creep_dims = w.dims[n_creep:]

    def _dot(arr, wei):
        # Contract the last `half` axes of `arr` with the last `half` axes of `wei`.
        half = wei.ndim // 2
        leading = arr.ndim - half
        arr_axes = np.arange(half) + leading
        wei_axes = np.arange(half) + half
        return np.tensordot(arr, wei, axes=(arr_axes, wei_axes))

    return xr.apply_ufunc(
        _dot,
        da,
        w,
        input_core_dims=[creep_dims, w.dims],
        output_core_dims=[creep_dims],
        dask="parallelized",
        output_dtypes=["float64"],
    )
def subset(  # noqa: C901
    ds: xr.Dataset,
    region: Optional[dict] = None,
    *,
    name: Optional[str] = None,
    method: Optional[
        str
    ] = None,  # FIXME: Once the region argument is removed, this should be made mandatory.
    tile_buffer: float = 0,
    **kwargs,
) -> xr.Dataset:
    """
    Subset the data to a region.

    Either creates a slice and uses the .sel() method, or customizes a call to clisops.subset()
    that allows for an automatic buffer around the region.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to be subsetted.
    region : dict, optional
        Deprecated argument that is there for legacy reasons and will be abandoned eventually.
        When given, `method` and `tile_buffer` fall back to the values found in it, and
        `kwargs` is replaced by a copy of `region[region["method"]]`.
    name : str, optional
        Used to rename the 'cat:domain' attribute.
    method : str
        ['gridpoint', 'bbox', 'shape', 'sel']
        If the method is `sel`, this is not a call to clisops but only a subsetting with
        the xarray .sel() function.
    tile_buffer : float
        For ['bbox', 'shape'], uses an approximation of the grid cell size to add a buffer
        around the requested region.
        This differs from clisops' 'buffer' argument in subset_shape().
    kwargs : dict
        Arguments to be sent to clisops.
        If the method is `sel`, the keys are the dimensions to subset and the values are
        turned into a slice.

    Returns
    -------
    xr.Dataset
        Subsetted Dataset. A line describing the operation is prepended to its 'history'
        attribute.

    Raises
    ------
    NotImplementedError
        If `method` is 'shape' and both `tile_buffer` and clisops' `buffer` are requested.
    ValueError
        If `method` is not one of the recognized subsetting types.

    Warns
    -----
    FutureWarning
        If the deprecated `region` argument is used.
    UserWarning
        If longitude/latitude had to be loaded from dask, or if `tile_buffer` is given for
        a method that does not use it.

    Notes
    -----
    If `ds.lon` or `ds.lat` are backed by dask, they are computed and written back
    to `ds` itself.

    See Also
    --------
    clisops.core.subset.subset_gridpoint, clisops.core.subset.subset_bbox, clisops.core.subset.subset_shape
    """
    if region is not None:
        warnings.warn(
            "The argument 'region' has been deprecated and will be abandoned in a future release.",
            category=FutureWarning,
        )
        method = method or region.get("method")
        if ("buffer" in region) and ("shape" in region):
            warnings.warn(
                "To avoid confusion with clisops' buffer argument, xscen's 'buffer' has been renamed 'tile_buffer'.",
                category=FutureWarning,
            )
            tile_buffer = tile_buffer or region.get("buffer", 0)
        else:
            tile_buffer = tile_buffer or region.get("tile_buffer", 0)
        # Copy so that the buffer adjustments below never touch the caller's dict.
        kwargs = deepcopy(region[region["method"]])

    if uses_dask(ds.lon) or uses_dask(ds.lat):
        warnings.warn("Loading longitude and latitude for more efficient subsetting.")
        ds["lon"], ds["lat"] = dask.compute(ds.lon, ds.lat)

    if tile_buffer > 0:
        if method not in ["bbox", "shape"]:
            warnings.warn(
                "tile_buffer has been specified, but is not used for the requested subsetting method.",
            )
        # estimate the model resolution
        if len(ds.lon.dims) == 1:  # 1D lat-lon
            lon_res = np.abs(ds.lon.diff("lon")[0].values)
            lat_res = np.abs(ds.lat.diff("lat")[0].values)
        else:
            lon_res = np.abs(ds.lon[0, 0].values - ds.lon[0, 1].values)
            lat_res = np.abs(ds.lat[0, 0].values - ds.lat[1, 0].values)

    if method in ["gridpoint"]:
        ds_subset = clisops.core.subset_gridpoint(ds, **kwargs)
        # np.size() rather than len(): a single point may be passed as a scalar `lon`,
        # on which len() would raise after the subsetting has already been done.
        new_history = (
            f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"{method} spatial subsetting on {np.size(kwargs['lon'])} coordinates - clisops v{clisops.__version__}"
        )

    elif method in ["bbox"]:
        if tile_buffer > 0:
            # adjust the boundaries
            kwargs["lon_bnds"] = (
                kwargs["lon_bnds"][0] - lon_res * tile_buffer,
                kwargs["lon_bnds"][1] + lon_res * tile_buffer,
            )
            kwargs["lat_bnds"] = (
                kwargs["lat_bnds"][0] - lat_res * tile_buffer,
                kwargs["lat_bnds"][1] + lat_res * tile_buffer,
            )

        # The CF-identified coordinates may differ from `lon`/`lat`; load them as well.
        if xc.core.utils.uses_dask(ds.cf["longitude"]):
            ds[ds.cf["longitude"].name].load()
        if xc.core.utils.uses_dask(ds.cf["latitude"]):
            ds[ds.cf["latitude"].name].load()

        ds_subset = clisops.core.subset_bbox(ds, **kwargs)
        new_history = (
            f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"{method} spatial subsetting with {'buffer=' + str(tile_buffer) if tile_buffer > 0 else 'no buffer'}"
            f", lon_bnds={np.array(kwargs['lon_bnds'])}, lat_bnds={np.array(kwargs['lat_bnds'])}"
            f" - clisops v{clisops.__version__}"
        )

    elif method in ["shape"]:
        if tile_buffer > 0:
            if kwargs.get("buffer") is not None:
                raise NotImplementedError(
                    "Both tile_buffer and clisops' buffer were requested. Use only one."
                )
            # Use the coarser of the two resolutions as the unit for the buffer.
            kwargs["buffer"] = np.max([lon_res, lat_res]) * tile_buffer

        ds_subset = clisops.core.subset_shape(ds, **kwargs)
        new_history = (
            f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"{method} spatial subsetting with {'buffer=' + str(tile_buffer) if tile_buffer > 0 else 'no buffer'}"
            f", shape={Path(kwargs['shape']).name if isinstance(kwargs['shape'], (str, Path)) else 'gpd.GeoDataFrame'}"
            f" - clisops v{clisops.__version__}"
        )

    elif method in ["sel"]:
        arg_sel = {dim: slice(*map(float, bounds)) for dim, bounds in kwargs.items()}
        ds_subset = ds.sel(**arg_sel)
        new_history = (
            f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"{method} subsetting with arguments {arg_sel}"
        )

    else:
        raise ValueError("Subsetting type not recognized")

    # Newest entry first, followed by any pre-existing history.
    history = (
        new_history + " \n " + ds_subset.attrs["history"]
        if "history" in ds_subset.attrs
        else new_history
    )
    ds_subset.attrs["history"] = history
    if name is not None:
        ds_subset.attrs["cat:domain"] = name
    return ds_subset