Source code for xscen.regrid

"""Functions to regrid datasets."""

import datetime
import operator
import os
import warnings
from copy import deepcopy
from typing import Optional, Union

import cartopy.crs as ccrs
import cf_xarray as cfxr
import numpy as np
import xarray as xr

try:
    import xesmf as xe
    from xesmf.frontend import Regridder
except ImportError:
    xe = None
    Regridder = "xesmf.Regridder"

from .config import parse_config

# TODO: Implement logging, warnings, etc.
# TODO: Add an option to call xesmf.util.grid_2d or xesmf.util.grid_global
# TODO: Implement support for an OBS2SIM kind of interpolation


__all__ = ["create_mask", "regrid_dataset"]


@parse_config
def regrid_dataset(  # noqa: C901
    ds: xr.Dataset,
    ds_grid: xr.Dataset,
    weights_location: Union[str, os.PathLike],
    *,
    regridder_kwargs: Optional[dict] = None,
    intermediate_grids: Optional[dict] = None,
    to_level: str = "regridded",
) -> xr.Dataset:
    """Regrid a dataset according to weights and a reference grid.

    Based on an intake_esm catalog, this function performs regridding on Zarr files.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to regrid. The Dataset needs to have lat/lon coordinates.
        Supports a 'mask' variable compatible with ESMF standards.
    ds_grid : xr.Dataset
        Destination grid. The Dataset needs to have lat/lon coordinates.
        Supports a 'mask' variable compatible with ESMF standards.
    weights_location : Union[str, os.PathLike]
        Path to the folder where weight files are saved. It is created if it does not exist.
    regridder_kwargs : dict, optional
        Arguments to send xe.Regridder(). If it contains `skipna` or `output_chunks`, those are passed
        to the regridder call directly. The dictionary itself is not modified.
        Set `reuse_weights` to False to force the weights to be recomputed even if a weight file exists.
    intermediate_grids : dict, optional
        This argument is used to do a regridding in many steps, regridding to regular
        grids before regridding to the final ds_grid.
        This is useful when there is a large jump in resolution between ds and ds grid.
        The format is a nested dictionary shown in Notes.
        If None, no intermediary grid is used, there is only a regrid from ds to ds_grid.
    to_level : str
        The processing level to assign to the output.
        Defaults to 'regridded'

    Returns
    -------
    xarray.Dataset
        Regridded dataset

    Notes
    -----
    intermediate_grids =
      {'name_of_inter_grid_1': {'cf_grid_2d': {arguments for util.cf_grid_2d },'regridder_kwargs':{arguments for xe.Regridder}},
      'name_of_inter_grid_2': dictionary_as_above}

    Weight files are named ``{id}_{domain}_regrid{step}{string arguments}.nc``, where ``id`` is the
    'cat:id' attribute of the input (or "weights") and ``domain`` is the 'cat:domain' attribute of
    `ds_grid` (omitted when absent). When such a file already exists, its weights are re-used.

    See Also
    --------
    xesmf.regridder, xesmf.util.cf_grid_2d
    """
    if xe is None:
        raise ImportError(
            "xscen's regridding functionality requires xESMF to work, please install that package."
        )
    regridder_kwargs = regridder_kwargs or {}
    intermediate_grids = intermediate_grids or {}

    # Target grids and their accompanying xe.Regridder() arguments, in regridding order:
    # first the intermediate grids, then the final ds_grid.
    ds_grids = [
        xe.util.cf_grid_2d(**inter["cf_grid_2d"])
        for inter in intermediate_grids.values()
    ]
    reg_arguments = [inter["regridder_kwargs"] for inter in intermediate_grids.values()]
    inter_grid_args = [inter["cf_grid_2d"] for inter in intermediate_grids.values()]
    ds_grids.append(ds_grid)
    reg_arguments.append(regridder_kwargs)

    # The destination domain is part of the weights file name so that regridding the same
    # dataset to different grids does not (re-)use the wrong weights.
    domain = ds_grid.attrs.get("cat:domain")

    if ds["lon"].equals(ds_grid["lon"]) and ds["lat"].equals(ds_grid["lat"]):
        # Same grid: no regridding required, only apply the mask if there is one.
        out = ds
        if "mask" in out:
            out = out.where(out.mask == 1)
            out = out.drop_vars(["mask"])
    else:
        out = None
        for i, (grid_i, reg_kwargs_i) in enumerate(zip(ds_grids, reg_arguments)):
            # After the first step, the result of the previous step is the new input.
            ds_in = ds if out is None else out
            # Work on a copy: the caller's (possibly config-provided) dict must not be modified.
            kwargs = deepcopy(reg_kwargs_i)

            os.makedirs(weights_location, exist_ok=True)

            # Give a unique name to the weights file. String-valued arguments (e.g. 'method')
            # are part of the name so that different configurations do not share a file.
            ds_id = ds_in.attrs.get("cat:id", "weights")
            prefix = f"{ds_id}_{domain}" if domain is not None else ds_id
            weights_filename = os.path.join(
                weights_location,
                f"{prefix}_regrid{i}"
                f"{'_'.join(v for v in kwargs.values() if isinstance(v, str))}.nc",
            )

            # Extract args that are to be given at call time.
            # output_chunks is only valid for xesmf >= 0.8, so don't add it by default to the call_kwargs
            call_kwargs = {"skipna": kwargs.pop("skipna", False)}
            if "output_chunks" in kwargs:
                call_kwargs["output_chunks"] = kwargs.pop("output_chunks")
            # Arguments recorded in the history, before the weights are injected below.
            kwargs_for_hist = deepcopy(kwargs)

            # Re-use existing weight file if possible
            if (
                os.path.isfile(weights_filename)
                and kwargs.get("reuse_weights") is not False
            ):
                # Load the weights eagerly and close the file, so that it is not held open
                # in case it gets rewritten afterwards.
                with xr.open_dataset(weights_filename) as weights_ds:
                    kwargs["weights"] = weights_ds.load()
                kwargs["reuse_weights"] = True

            regridder = _regridder(
                ds_in=ds_in, ds_grid=grid_i, filename=weights_filename, **kwargs
            )

            # The regridder (when fed Datasets) doesn't like if 'mask' is present.
            if "mask" in ds_in:
                ds_in = ds_in.drop_vars(["mask"])
            out = regridder(ds_in, keep_attrs=True, **call_kwargs)

            # double-check that grid_mapping information is transferred
            out = _transfer_grid_mapping(out, ds_in, grid_i)

            # History
            kwargs_for_hist.setdefault("method", regridder.method)
            now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            if i < len(inter_grid_args):
                new_history = (
                    f"[{now}] "
                    f"regridded with regridder arguments {kwargs_for_hist} to a xesmf"
                    f" cf_grid_2d with arguments {inter_grid_args[i]} - xESMF v{xe.__version__}"
                )
            else:
                new_history = (
                    f"[{now}] "
                    f"regridded with arguments {kwargs_for_hist} - xESMF v{xe.__version__}"
                )
            out.attrs["history"] = (
                f"{new_history}\n{out.attrs['history']}"
                if "history" in out.attrs
                else new_history
            )

    out = out.drop_vars("latitude_longitude", errors="ignore")
    # Attrs
    out.attrs["cat:processing_level"] = to_level
    out.attrs["cat:domain"] = ds_grid.attrs.get("cat:domain")
    return out


def _transfer_grid_mapping(
    out: xr.Dataset, ds_in: xr.Dataset, ds_grid: xr.Dataset
) -> xr.Dataset:
    """Make the grid_mapping information of a regridded dataset match its destination grid.

    Parameters
    ----------
    out : xr.Dataset
        Regridded dataset. Its variables' attributes are modified in place.
    ds_in : xr.Dataset
        Dataset that was given to the regridder.
    ds_grid : xr.Dataset
        Destination grid.

    Returns
    -------
    xr.Dataset
        `out`, with the grid_mapping of `ds_grid` if it has exactly one, or without the
        grid_mapping of `ds_in` if `ds_grid` has none.
    """
    if any("grid_mapping" in ds_grid[da].attrs for da in ds_grid.data_vars):
        gridmap = np.unique(
            [
                ds_grid[da].attrs["grid_mapping"]
                for da in ds_grid.data_vars
                if "grid_mapping" in ds_grid[da].attrs
                and ds_grid[da].attrs["grid_mapping"] in ds_grid
            ]
        )
        if len(gridmap) != 1:
            warnings.warn("Could not determine and transfer grid_mapping information.")
            return out

        # Add the grid_mapping attribute
        for v in out.data_vars:
            out[v].attrs["grid_mapping"] = gridmap[0]
        # Add the grid_mapping coordinate
        if gridmap[0] not in out:
            out = out.assign_coords({gridmap[0]: ds_grid[gridmap[0]]})
        # Regridder seems to seriously mess up the rotated dimensions
        for d in out.lon.dims:
            out[d] = ds_grid[d]
            if d not in out.coords:
                out = out.assign_coords({d: ds_grid[d]})
        return out

    # The destination grid has no grid_mapping: remove the one inherited from the input.
    gridmap = np.unique(
        [
            ds_in[da].attrs["grid_mapping"]
            for da in ds_in.data_vars
            if "grid_mapping" in ds_in[da].attrs
        ]
    )
    # Remove the original grid_mapping attribute
    for v in out.data_vars:
        out[v].attrs.pop("grid_mapping", None)
    # Remove the original grid_mapping coordinate if it is still in the output
    return out.drop_vars(set(gridmap).intersection(out.variables))
@parse_config
def create_mask(ds: Union[xr.Dataset, xr.DataArray], mask_args: dict) -> xr.DataArray:
    """Create a 0-1 mask based on incoming arguments.

    Parameters
    ----------
    ds : xr.Dataset or xr.DataArray
        Dataset or DataArray to be evaluated. If it has a 'time' dimension,
        only the first time step is used.
    mask_args : dict
        Instructions to build the mask (fields listed in the Notes).

    Notes
    -----
    'mask' fields:
        variable: str, optional
            Variable on which to base the mask. Required if `ds` is not a DataArray.
        where_operator: str, optional
            Conditional operator, one of '<', '<=', '==', '!=', '>=', '>'.
        where_threshold: str, optional
            Value threshold to be used in conjunction with where_operator.
            Required if `where_operator` is given.
        mask_nans: bool, optional
            Whether to set the mask to 0 where the data is NaN. Defaults to False.

    Returns
    -------
    xr.DataArray
        Mask array.

    Raises
    ------
    ValueError
        If the data to base the mask on cannot be determined, or if `where_operator`
        is not a supported operator.
    """
    ops = {
        "<": operator.lt,
        "<=": operator.le,
        "==": operator.eq,
        "!=": operator.ne,
        ">=": operator.ge,
        ">": operator.gt,
    }

    mask_args = mask_args or {}
    if isinstance(ds, xr.DataArray):
        data = ds
    elif isinstance(ds, xr.Dataset) and "variable" in mask_args:
        data = ds[mask_args["variable"]]
    else:
        raise ValueError("Could not determine what to base the mask on.")

    if "time" in data.dims:
        data = data.isel(time=0)

    if "where_operator" in mask_args:
        op = mask_args["where_operator"]
        if op not in ops:
            raise ValueError(
                f"Unsupported where_operator '{op}'. Valid operators are: {list(ops)}."
            )
        mask = xr.where(ops[op](data, mask_args["where_threshold"]), 1, 0)
    else:
        mask = xr.ones_like(data)

    mask_nans = mask_args.get("mask_nans", False)
    if mask_nans is True:
        # The NaN locations must come from the original data: ones_like fills NaNs with 1
        # and the comparisons above turn them into 0 (or 1 for '!='). Note that np.isreal
        # cannot be used here, since NaN is a real number.
        mask = mask.where(data.notnull(), other=0)

    # Attributes
    if "where_operator" in mask_args:
        variable = mask_args.get("variable", data.name)
        mask.attrs["where_threshold"] = (
            f"{variable} {mask_args['where_operator']} {mask_args['where_threshold']}"
        )
    mask.attrs["mask_nans"] = f"{mask_nans}"

    return mask
def _regridder(
    ds_in: xr.Dataset,
    ds_grid: xr.Dataset,
    filename: Union[str, os.PathLike],
    *,
    method: str = "bilinear",
    unmapped_to_nan: Optional[bool] = True,
    **kwargs,
) -> Regridder:
    """Call to xesmf Regridder with a few default arguments.

    Parameters
    ----------
    ds_in : xr.Dataset
        Incoming grid. The Dataset needs to have lat/lon coordinates. It is not modified.
    ds_grid : xr.Dataset
        Destination grid. The Dataset needs to have lat/lon coordinates. It is not modified.
    filename : str or os.PathLike
        Path to the NetCDF file with weights information. The weights are written there,
        unless `reuse_weights` is True and the file already exists.
    method : str
        Interpolation method.
    unmapped_to_nan : bool, optional
        Argument sent to xe.Regridder().
    \*\*kwargs
        Other arguments sent to xe.Regridder().

    Returns
    -------
    xe.frontend.Regridder
        Regridder object
    """
    if method.startswith("conservative"):
        # Conservative methods need cell bounds: create them for rotated pole grids that only
        # have 2D lon/lat cell centers. Shallow copies are updated so that the caller's
        # datasets are left untouched.
        if (
            ds_in.cf["longitude"].ndim == 2
            and "longitude" not in ds_in.cf.bounds
            and "rotated_pole" in ds_in
        ):
            ds_in = ds_in.copy()
            ds_in.update(create_bounds_rotated_pole(ds_in))
        if (
            ds_grid.cf["longitude"].ndim == 2
            and "longitude" not in ds_grid.cf.bounds
            and "rotated_pole" in ds_grid
        ):
            ds_grid = ds_grid.copy()
            ds_grid.update(create_bounds_rotated_pole(ds_grid))

    regridder = xe.Regridder(
        ds_in=ds_in,
        ds_out=ds_grid,
        method=method,
        unmapped_to_nan=unmapped_to_nan,
        **kwargs,
    )
    # Save the weights, unless they were re-used from this very file.
    # (The previous `~os.path.isfile(...)` was always truthy: bitwise not on a bool.)
    if not (kwargs.get("reuse_weights", False) and os.path.isfile(filename)):
        regridder.to_netcdf(filename)

    return regridder


def create_bounds_rotated_pole(ds: xr.Dataset) -> xr.Dataset:
    """Create bounds for rotated pole datasets.

    The cell bounds are computed along the rotated `rlat`/`rlon` coordinates, then
    projected to regular longitudes and latitudes.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with `rlat`/`rlon` coordinates, `lon`/`lat` coordinates and a `rotated_pole`
        variable with `grid_north_pole_longitude` and `grid_north_pole_latitude` attributes.
        Its optional `north_pole_grid_longitude` attribute defaults to 0, as in the CF conventions.

    Returns
    -------
    xr.Dataset
        Dataset with `lon_bounds` and `lat_bounds` (dimensions of `lon`, then "bounds"),
        along with `lon`, `lat`, `rotated_pole`, `rlat` and `rlon`.
    """
    ds = ds.cf.add_bounds(["rlat", "rlon"])
    # In "vertices" format then expand to 2D. From (N, 2) to (N+1,) to (N+1, M+1)
    rlatv1D = cfxr.bounds_to_vertices(ds.rlat_bounds, "bounds")
    rlonv1D = cfxr.bounds_to_vertices(ds.rlon_bounds, "bounds")
    rlatv = rlatv1D.expand_dims(rlon_vertices=rlonv1D).transpose(
        "rlon_vertices", "rlat_vertices"
    )
    rlonv = rlonv1D.expand_dims(rlat_vertices=rlatv1D).transpose(
        "rlon_vertices", "rlat_vertices"
    )

    # Get cartopy's crs for the projection
    rotated_pole = ds.rotated_pole
    RP = ccrs.RotatedPole(
        pole_longitude=rotated_pole.grid_north_pole_longitude,
        pole_latitude=rotated_pole.grid_north_pole_latitude,
        central_rotated_longitude=rotated_pole.attrs.get(
            "north_pole_grid_longitude", 0.0
        ),
    )
    PC = ccrs.PlateCarree()

    # Project points
    pts = PC.transform_points(RP, rlonv.values, rlatv.values)
    lonv = rlonv.copy(data=pts[..., 0]).rename("lon_vertices")
    latv = rlatv.copy(data=pts[..., 1]).rename("lat_vertices")

    # Back to CF bounds format. From (N+1, M+1) to (4, N, M)
    lonb = cfxr.vertices_to_bounds(lonv, ("bounds", "rlon", "rlat")).rename(
        "lon_bounds"
    )
    latb = cfxr.vertices_to_bounds(latv, ("bounds", "rlon", "rlat")).rename(
        "lat_bounds"
    )

    # Create dataset, set coords and attrs
    ds_bnds = xr.merge([lonb, latb]).assign(
        lon=ds.lon, lat=ds.lat, rotated_pole=ds.rotated_pole
    )
    ds_bnds["rlat"] = ds.rlat
    ds_bnds["rlon"] = ds.rlon
    ds_bnds.lat.attrs["bounds"] = "lat_bounds"
    ds_bnds.lon.attrs["bounds"] = "lon_bounds"
    return ds_bnds.transpose(*ds.lon.dims, "bounds")