# Source code for xscen.spatial

"""Spatial tools."""

import datetime
import itertools
import logging
import warnings
from collections.abc import Sequence
from pathlib import Path
from typing import Optional, Union

import clisops.core.subset
import dask
import geopandas as gpd
import numpy as np
import sparse as sp
import xarray as xr
import xclim as xc

from .config import parse_config

logger = logging.getLogger(__name__)

__all__ = [
    "creep_fill",
    "creep_weights",
    "subset",
]


@parse_config
def creep_weights(mask: xr.DataArray, n: int = 1, mode: str = "clip") -> xr.DataArray:
    """Compute weights for the creep fill.

    The output is a sparse matrix with the same dimensions as `mask`, twice.

    Parameters
    ----------
    mask : DataArray
        A boolean DataArray. False values are candidates to the filling.
        Usually they represent missing values (`mask = da.notnull()`).
        All dimensions are creep filled.
    n : int
        The order of neighbouring to use. 1 means only the adjacent grid cells are used.
    mode : {'clip', 'wrap'}
        If a cell is on the edge of the domain, `mode='wrap'` will wrap around to find neighbours.

    Returns
    -------
    DataArray
        Weights. The dot product must be taken over the last N dimensions.

    Raises
    ------
    ValueError
        If `mode` is neither 'clip' nor 'wrap'.
    """
    # Fix: validate `mode` up front. Previously the check lived inside the
    # per-cell loop, so it was only reached when a masked cell existed, and
    # np.ravel_multi_index would usually raise an opaque numpy error first.
    if mode not in ("clip", "wrap"):
        raise ValueError("mode must be either 'clip' or 'wrap'")
    da = mask
    mask = da.values
    # All index offsets within a (2n+1)^ndim hypercube centred on each cell.
    neighbors = np.array(
        list(itertools.product(*[np.arange(-n, n + 1) for j in range(mask.ndim)]))
    ).T
    src = []
    dst = []
    w = []
    it = np.nditer(mask, flags=["f_index", "multi_index"], order="C")
    for i in it:
        if not i:
            # Masked (False) cell: fill from the average of valid neighbours.
            neigh_idx_2d = np.atleast_2d(it.multi_index).T + neighbors
            neigh_idx_1d = np.ravel_multi_index(
                neigh_idx_2d, mask.shape, order="C", mode=mode
            )
            if mode == "clip":
                # Clipping can collapse several offsets onto the same edge
                # cell; deduplicate so each neighbour is counted only once.
                neigh_idx = np.unravel_index(
                    np.unique(neigh_idx_1d), mask.shape, order="C"
                )
            else:  # mode == "wrap" (guaranteed by the validation above)
                neigh_idx = np.unravel_index(neigh_idx_1d, mask.shape, order="C")
            neigh = mask[neigh_idx]
            N = neigh.sum()
            if N > 0:
                # Uniform weights 1/N over the N valid neighbours.
                src.extend([it.multi_index] * N)
                dst.extend(np.stack(neigh_idx)[:, neigh].T)
                w.extend([1 / N] * N)
            else:
                # No valid neighbour at all: the cell stays NaN after the dot product.
                src.extend([it.multi_index])
                dst.extend([it.multi_index])
                w.extend([np.nan])
        else:
            # Valid cell: identity weight, value is passed through unchanged.
            src.extend([it.multi_index])
            dst.extend([it.multi_index])
            w.extend([1])
    # Stack source and destination indices into one (2*ndim, nnz) coordinate
    # array for the sparse COO constructor.
    crds = np.concatenate((np.array(src).T, np.array(dst).T), axis=0)
    return xr.DataArray(
        sp.COO(crds, w, (*da.shape, *da.shape)),
        dims=[f"{d}_out" for d in da.dims] + list(da.dims),
        coords=da.coords,
        name="creep_fill_weights",
    )
@parse_config
def creep_fill(da: xr.DataArray, w: xr.DataArray) -> xr.DataArray:
    """Creep fill using pre-computed weights.

    Parameters
    ----------
    da: DataArray
        A DataArray sharing the dimensions with the one used to compute the weights.
        It can have other dimensions.
        Dask is supported as long as there are no chunks over the creeped dims.
    w: DataArray
        The result of `creep_weights`.

    Returns
    -------
    xarray.DataArray, same shape as `da`, but values filled according to `w`.

    Examples
    --------
    >>> w = creep_weights(da.isel(time=0).notnull(), n=1)
    >>> da_filled = creep_fill(da, w)
    """

    def _tensordot(values, weights):
        # The weights array has the spatial dims twice (out + in); contract
        # the trailing `half` dims of `values` with the input dims of `weights`.
        half = weights.ndim // 2
        lead = values.ndim - half
        return np.tensordot(
            values,
            weights,
            axes=(list(range(lead, lead + half)), list(range(half, 2 * half))),
        )

    half = w.ndim // 2
    return xr.apply_ufunc(
        _tensordot,
        da,
        w,
        input_core_dims=[w.dims[half:], w.dims],
        output_core_dims=[w.dims[half:]],
        dask="parallelized",
        output_dtypes=["float64"],
    )
def subset(
    ds: xr.Dataset,
    method: str,
    *,
    name: Optional[str] = None,
    tile_buffer: float = 0,
    **kwargs,
) -> xr.Dataset:
    r"""Subset the data to a region.

    Either creates a slice and uses the .sel() method, or customizes a call to
    clisops.subset() that allows for an automatic buffer around the region.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to be subsetted.
    method : str
        One of ['gridpoint', 'bbox', 'shape', 'sel']. With 'sel', this is not a
        call to clisops but only a subsetting with the xarray .sel() function.
    name: str, optional
        Used to rename the 'cat:domain' attribute.
    tile_buffer : float
        For ['bbox', 'shape'], uses an approximation of the grid cell size to add
        a buffer around the requested region. This differs from clisops' 'buffer'
        argument in subset_shape().
    \*\*kwargs : dict
        Arguments to be sent to clisops. See relevant function for details.
        Depending on the method, required kwargs are:

        - gridpoint: lon, lat
        - bbox: lon_bnds, lat_bnds
        - shape: shape
        - sel: slices for each dimension

    Returns
    -------
    xr.Dataset
        Subsetted Dataset.

    See Also
    --------
    clisops.core.subset.subset_gridpoint, clisops.core.subset.subset_bbox, clisops.core.subset.subset_shape
    """
    # tile_buffer only makes sense for the region-based clisops methods.
    if tile_buffer > 0 and method in ["gridpoint", "sel"]:
        warnings.warn(
            f"tile_buffer is not used for the '{method}' method. Ignoring the argument.",
            UserWarning,
        )

    # Make sure the CF coordinate names are identifiable before subsetting.
    if "latitude" not in ds.cf or "longitude" not in ds.cf:
        ds = ds.cf.guess_coord_axis()

    # Dispatch to the private helper implementing the requested method.
    if method == "gridpoint":
        return _subset_gridpoint(ds, name=name, **kwargs)
    if method == "bbox":
        return _subset_bbox(ds, name=name, tile_buffer=tile_buffer, **kwargs)
    if method == "shape":
        return _subset_shape(ds, name=name, tile_buffer=tile_buffer, **kwargs)
    if method == "sel":
        return _subset_sel(ds, name=name, **kwargs)
    raise ValueError(
        "Subsetting type not recognized. Use 'gridpoint', 'bbox', 'shape' or 'sel'."
    )
def _subset_gridpoint(
    ds: xr.Dataset,
    lon: Union[float, Sequence[float], xr.DataArray],
    lat: Union[float, Sequence[float], xr.DataArray],
    *,
    name: Optional[str] = None,
    **kwargs,
) -> xr.Dataset:
    r"""Subset the data to a gridpoint.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to be subsetted.
    lon : float or Sequence[float] or xr.DataArray
        Longitude coordinate(s). Must be of the same length as lat.
    lat : float or Sequence[float] or xr.DataArray
        Latitude coordinate(s). Must be of the same length as lon.
    name: str, optional
        Used to rename the 'cat:domain' attribute.
    \*\*kwargs : dict
        Other arguments to be sent to clisops. Possible kwargs are:

        - start_date (str): Start date for the subset in the format 'YYYY-MM-DD'.
        - end_date (str): End date for the subset in the format 'YYYY-MM-DD'.
        - first_level (int or float): First level of the subset.
        - last_level (int or float): Last level of the subset.
        - tolerance (float): Masks values if the distance to the nearest gridpoint is larger than tolerance in meters.
        - add_distance (bool): If True, adds a variable with the distance to the nearest gridpoint.

    Returns
    -------
    xr.Dataset
        Subsetted Dataset.
    """
    # Eagerly load lon/lat if they are dask-backed, to avoid repeated computes.
    ds = _load_lon_lat(ds)
    # Wrap scalars in a list: clisops expects iterables, and len() is used below.
    if not hasattr(lon, "__iter__"):
        lon = [lon]
    if not hasattr(lat, "__iter__"):
        lat = [lat]
    ds_subset = clisops.core.subset_gridpoint(ds, lon=lon, lat=lat, **kwargs)
    new_history = (
        f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
        f"gridpoint spatial subsetting on {len(lon)} coordinates - clisops v{clisops.__version__}"
    )
    return update_history_and_name(ds_subset, new_history, name)


def _subset_bbox(
    ds: xr.Dataset,
    lon_bnds: Union[tuple[float, float], list[float]],
    lat_bnds: Union[tuple[float, float], list[float]],
    *,
    name: Optional[str] = None,
    tile_buffer: float = 0,
    **kwargs,
) -> xr.Dataset:
    r"""Subset the data to a bounding box.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to be subsetted.
    lon_bnds : tuple or list of two floats
        Longitude boundaries of the bounding box.
    lat_bnds : tuple or list of two floats
        Latitude boundaries of the bounding box.
    name: str, optional
        Used to rename the 'cat:domain' attribute.
    tile_buffer: float
        Uses an approximation of the grid cell size to add a dynamic buffer around the requested region.
    \*\*kwargs : dict
        Other arguments to be sent to clisops. Possible kwargs are:

        - start_date (str): Start date for the subset in the format 'YYYY-MM-DD'.
        - end_date (str): End date for the subset in the format 'YYYY-MM-DD'.
        - first_level (int or float): First level of the subset.
        - last_level (int or float): Last level of the subset.
        - time_values (Sequence[str]): A list of datetime strings to subset.
        - level_values (Sequence[int or float]): A list of levels to subset.

    Returns
    -------
    xr.Dataset
        Subsetted Dataset.
    """
    ds = _load_lon_lat(ds)

    # Widen the box by `tile_buffer` grid cells on each side, using the
    # estimated grid resolution as the cell size.
    if tile_buffer > 0:
        lon_res, lat_res = _estimate_grid_resolution(ds)
        lon_bnds = (
            lon_bnds[0] - lon_res * tile_buffer,
            lon_bnds[1] + lon_res * tile_buffer,
        )
        lat_bnds = (
            lat_bnds[0] - lat_res * tile_buffer,
            lat_bnds[1] + lat_res * tile_buffer,
        )

    ds_subset = clisops.core.subset_bbox(
        ds, lon_bnds=lon_bnds, lat_bnds=lat_bnds, **kwargs
    )
    new_history = (
        f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
        f"bbox spatial subsetting with {'buffer=' + str(tile_buffer) if tile_buffer > 0 else 'no buffer'}"
        f", lon_bnds={np.array(lon_bnds)}, lat_bnds={np.array(lat_bnds)}"
        f" - clisops v{clisops.__version__}"
    )
    return update_history_and_name(ds_subset, new_history, name)


def _subset_shape(
    ds: xr.Dataset,
    shape: Union[str, Path, gpd.GeoDataFrame],
    *,
    name: Optional[str] = None,
    tile_buffer: float = 0,
    **kwargs,
) -> xr.Dataset:
    r"""Subset the data to a shape.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to be subsetted.
    shape : str or gpd.GeoDataFrame
        Path to the shapefile or GeoDataFrame.
    name: str, optional
        Used to rename the 'cat:domain' attribute.
    tile_buffer: float
        Uses an approximation of the grid cell size to add a buffer around the requested region.
    \*\*kwargs : dict
        Other arguments to be sent to clisops. Possible kwargs are:

        - raster_crs (str or int): EPSG number or PROJ4 string.
        - shape_crs (str or int): EPSG number or PROJ4 string.
        - buffer (float): Buffer size to add around the shape. Units are based on the shape degrees/metres.
        - start_date (str): Start date for the subset in the format 'YYYY-MM-DD'.
        - end_date (str): End date for the subset in the format 'YYYY-MM-DD'.
        - first_level (int or float): First level of the subset.
        - last_level (int or float): Last level of the subset.

    Returns
    -------
    xr.Dataset
        Subsetted Dataset.
    """
    ds = _load_lon_lat(ds)

    if tile_buffer > 0:
        # tile_buffer is translated into clisops' own 'buffer' kwarg, so both
        # at once would be ambiguous.
        if kwargs.get("buffer") is not None:
            raise ValueError(
                "Both tile_buffer and clisops' buffer were requested. Use only one."
            )
        lon_res, lat_res = _estimate_grid_resolution(ds)
        # Use the coarser of the two resolutions to size the buffer.
        kwargs["buffer"] = np.max([lon_res, lat_res]) * tile_buffer

    ds_subset = clisops.core.subset_shape(ds, shape=shape, **kwargs)
    new_history = (
        f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
        f"shape spatial subsetting with {'buffer=' + str(tile_buffer) if tile_buffer > 0 else 'no buffer'}"
        f", shape={Path(shape).name if isinstance(shape, (str, Path)) else 'gpd.GeoDataFrame'}"
        f" - clisops v{clisops.__version__}"
    )
    return update_history_and_name(ds_subset, new_history, name)


def _subset_sel(ds: xr.Dataset, *, name: Optional[str] = None, **kwargs) -> xr.Dataset:
    r"""Subset the data using the .sel() method.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to be subsetted.
    name: str, optional
        Used to rename the 'cat:domain' attribute.
    \*\*kwargs : dict
        The keys are the dimensions to subset and the values are turned into a slice.

    Returns
    -------
    xr.Dataset
        Subsetted Dataset.
    """
    # Create a dictionary with slices for each dimension
    # NOTE(review): bounds are coerced to float, so this only supports numeric
    # dimensions (e.g. rlat/rlon) — string/datetime bounds would raise. Confirm
    # with callers before extending.
    arg_sel = {dim: slice(*map(float, bounds)) for dim, bounds in kwargs.items()}

    # Subset the dataset
    ds_subset = ds.sel(**arg_sel)

    # Update the history attribute
    new_history = (
        f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
        f"sel subsetting with arguments {arg_sel}"
    )
    return update_history_and_name(ds_subset, new_history, name)


def _load_lon_lat(ds: xr.Dataset) -> xr.Dataset:
    """Load longitude and latitude for more efficient subsetting."""
    # Compute dask-backed coordinates once, in place, so later subsetting does
    # not repeatedly trigger the same computation.
    if xc.core.utils.uses_dask(ds.cf["longitude"]):
        logger.info("Loading longitude for more efficient subsetting.")
        (ds[ds.cf["longitude"].name],) = dask.compute(ds[ds.cf["longitude"].name])
    if xc.core.utils.uses_dask(ds.cf["latitude"]):
        logger.info("Loading latitude for more efficient subsetting.")
        (ds[ds.cf["latitude"].name],) = dask.compute(ds[ds.cf["latitude"].name])
    return ds


def _estimate_grid_resolution(ds: xr.Dataset) -> tuple[float, float]:
    """Estimate the grid resolution along longitude and latitude.

    Since this is to compute a buffer, the maximum difference between adjacent
    coordinates is taken as an approximation.
    """
    # Estimate the grid resolution
    if len(ds.lon.dims) == 1:  # 1D lat-lon
        lon_res = np.abs(ds.lon.diff("lon").max().values)
        lat_res = np.abs(ds.lat.diff("lat").max().values)
    else:
        # 2D lat-lon (curvilinear grid): differentiate along the CF X/Y dims.
        lon_res = np.abs(ds.lon.diff(ds.cf["X"].name).max().values)
        lat_res = np.abs(ds.lat.diff(ds.cf["Y"].name).max().values)

    return lon_res, lat_res


def update_history_and_name(
    ds_subset: xr.Dataset, new_history: str, name: Optional[str]
) -> xr.Dataset:
    """Prepend `new_history` to the dataset's 'history' attribute and, if `name` is given, set the 'cat:domain' attribute.

    Mutates and returns `ds_subset`.
    """
    # Newest entry goes first, separated from the existing history by " \n ".
    history = (
        new_history + " \n " + ds_subset.attrs["history"]
        if "history" in ds_subset.attrs
        else new_history
    )
    ds_subset.attrs["history"] = history
    if name is not None:
        ds_subset.attrs["cat:domain"] = name
    return ds_subset