"""Functions to regrid datasets."""
import datetime
import operator
import os
import random
import string
import warnings
from copy import deepcopy
from pathlib import Path
import cartopy.crs as ccrs
import cf_xarray as cfxr
import xarray as xr
from xclim.core.units import convert_units_to
# xESMF is an optional dependency: this module must remain importable without it.
try:
    import xesmf as xe
    from xesmf.frontend import Regridder
except (ImportError, KeyError) as import_err:
    is_key_error = isinstance(import_err, KeyError)
    # Only the KeyError whose first argument is "Author" is tolerated;
    # any other KeyError is propagated unchanged.
    if is_key_error and import_err.args[0] != "Author":
        raise import_err
    if is_key_error:
        warnings.warn(
            "The xesmf package could not be imported due to a known KeyError bug that occurs with some "
            "older versions of ESMF and specific execution setups (such as debugging on a Windows machine). "
            "As a workaround, try installing 'importlib-metadata <8.0.0' and/or updating ESMF. If you do not "
            "need 'xesmf' functionalities (e.g. regridding), you can ignore this warning.",
            stacklevel=2,
        )
    # Placeholders: `xe is None` flags the missing package, and the string
    # stands in for the class where `Regridder` is used as an annotation.
    xe, Regridder = None, "xesmf.Regridder"
from .config import parse_config
from .spatial import get_crs, get_grid_mapping
__all__ = ["create_bounds_gridmapping", "create_mask", "regrid_dataset"]
@parse_config
def regrid_dataset(  # noqa: C901
    ds: xr.Dataset,
    ds_grid: xr.Dataset,
    *,
    weights_location: str | os.PathLike | None = None,
    regridder_kwargs: dict | None = None,
    intermediate_grids: dict | None = None,
    to_level: str = "regridded",
) -> xr.Dataset:
    """
    Regrid a dataset according to weights and a reference grid.

    Based on an intake_esm catalog, this function performs regridding on Zarr files.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to regrid. The Dataset needs to have lat/lon coordinates.
        Supports a 'mask' variable compatible with ESMF standards.
    ds_grid : xr.Dataset
        Destination grid. The Dataset needs to have lat/lon coordinates.
        Supports a 'mask' variable compatible with ESMF standards.
    weights_location : Union[str, os.PathLike], optional
        Path to the folder where weight file is saved. Leave as None to force re-computation of weights.
        Note that in order to reuse the weights, ds and ds_grid should both have the 'cat:id' and 'cat:domain' attributes
        (random identifiers are generated for missing attributes, so the file name changes on every call).
    regridder_kwargs : dict, optional
        Arguments to send xe.Regridder(). If it contains `skipna` or `output_chunks`, those
        are passed to the regridder call directly.
    intermediate_grids : dict, optional
        This argument is used to do a regridding in many steps, regridding to regular
        grids before regridding to the final ds_grid.
        This is useful when there is a large jump in resolution between ds and ds grid.
        The format is a nested dictionary shown in Notes.
        If None, no intermediary grid is used, there is only a regrid from ds to ds_grid.
    to_level : str
        The processing level to assign to the output.
        Defaults to 'regridded'.

    Returns
    -------
    xarray.Dataset
        Regridded dataset.

    Raises
    ------
    ImportError
        If xESMF could not be imported.

    See Also
    --------
    xesmf.Regridder : Used to perform regridding operations.
    xesmf.util.cf_grid_2d : Used to create grids that follow CF conventions.

    Notes
    -----
    intermediate_grids =
      {'name_of_inter_grid_1': {'cf_grid_2d': {arguments for util.cf_grid_2d },'regridder_kwargs':{arguments for xe.Regridder}},
      'name_of_inter_grid_2': dictionary_as_above}
    """
    if xe is None:
        raise ImportError("xscen's regridding functionality requires xESMF to work, please install that package.")

    ds = ds.copy()
    regridder_kwargs = regridder_kwargs or {}

    # We modify the dataset later, so we need to keep track of whether it had lon_bounds and lat_bounds to begin with
    has_lon_bounds = "lon_bounds" in ds
    has_lat_bounds = "lat_bounds" in ds

    # Generate unique IDs to name the weights file, but remove the members and experiment from the dataset ID
    if weights_location is not None:
        dsid = (
            ds.attrs.get("cat:id", _generate_random_string(15))
            .replace(ds.attrs.get("cat:member", ""), "")
            .replace(ds.attrs.get("cat:driving_member", ""), "")
            .replace(ds.attrs.get("cat:experiment", ""), "")
        )
        dsid = f"{dsid}_{ds.attrs.get('cat:domain', _generate_random_string(15))}"
        gridid = f"{ds_grid.attrs.get('cat:id', _generate_random_string(15))}_{ds_grid.attrs.get('cat:domain', _generate_random_string(15))}"

    # Successive target grids and their accompanying xe.Regridder() arguments.
    # Intermediate grids (if any) come first, the final ds_grid is always last.
    target_grids = []
    target_kwargs = []
    if intermediate_grids:
        for dict_inter in intermediate_grids.values():
            target_kwargs.append(dict_inter["regridder_kwargs"])
            target_grids.append(xe.util.cf_grid_2d(**dict_inter["cf_grid_2d"]))
    target_grids.append(ds_grid)
    target_kwargs.append(regridder_kwargs)

    out = None

    # If the grid is the same, skip the call to xESMF
    if ds["lon"].equals(ds_grid["lon"]) and ds["lat"].equals(ds_grid["lat"]):
        out = ds
        if "mask" in out:
            out = out.where(out.mask == 1)
            out = out.drop_vars(["mask"])
        if "mask" in ds_grid:
            out = out.where(ds_grid.mask == 1)

    else:
        # Dedicated loop names are used so the `ds_grid` and `regridder_kwargs`
        # parameters are not rebound while iterating.
        for i, (step_grid, step_kwargs) in enumerate(zip(target_grids, target_kwargs, strict=False)):
            # If this is not the first iteration, use the result of the previous step as input.
            # An explicit None check is used rather than relying on the truthiness of a Dataset.
            ds = out if out is not None else ds

            kwargs = deepcopy(step_kwargs)

            # Prepare the weight file
            if weights_location is not None:
                Path(weights_location).mkdir(parents=True, exist_ok=True)
                # String-valued regridder arguments (e.g. the method) are part of the file name
                str_args = "_".join(v for v in kwargs.values() if isinstance(v, str))
                weights_filename = Path(weights_location, f"{dsid}_{gridid}_regrid{i}{str_args}.nc")

                # Reuse existing weight file if possible, unless reuse_weights=False was explicitly requested
                if weights_filename.is_file() and kwargs.get("reuse_weights") is not False:
                    kwargs["weights"] = weights_filename
                    kwargs["reuse_weights"] = True
            else:
                weights_filename = None

            # Extract args that are to be given at call time.
            # output_chunks is only valid for xesmf >= 0.8, so don't add it by default to the call_kwargs
            call_kwargs = {"skipna": kwargs.pop("skipna", False)}
            if "output_chunks" in kwargs:
                call_kwargs["output_chunks"] = kwargs.pop("output_chunks")

            regridder = _regridder(ds_in=ds, ds_grid=step_grid, filename=weights_filename, **kwargs)

            # The regridder (when fed Datasets) doesn't like if 'mask' is present.
            if "mask" in ds:
                ds = ds.drop_vars(["mask"])

            out = regridder(ds, keep_attrs=True, **call_kwargs)

            # Double-check that grid_mapping information is transferred
            gridmap_out = get_grid_mapping(step_grid)
            if gridmap_out:
                # Regridder seems to seriously mess up the rotated dimensions
                for d in out.lon.dims:
                    out[d] = step_grid[d]
                    if d not in out.coords:
                        out = out.assign_coords({d: step_grid[d]})
                # Add the grid_mapping attribute
                for v in out.data_vars:
                    if any(d in out[v].dims for d in [out.cf.axes["X"][0], out.cf.axes["Y"][0], "loc"]):
                        out[v].attrs["grid_mapping"] = gridmap_out
                # Add the grid_mapping coordinate
                if gridmap_out not in out:
                    out = out.assign_coords({gridmap_out: step_grid[gridmap_out]})
            else:
                gridmap_in = get_grid_mapping(ds)
                # Remove the original grid_mapping attribute
                for v in out.data_vars:
                    out[v].attrs.pop("grid_mapping", None)
                # Remove the original grid_mapping coordinate if it is still in the output
                out = out.drop_vars(gridmap_in, errors="ignore")

            # cf_grid_2d adds temporary variables that we don't want to keep
            if "lon_bounds" in out and not has_lon_bounds:
                out = out.drop_vars("lon_bounds")
            if "lat_bounds" in out and not has_lat_bounds:
                out = out.drop_vars("lat_bounds")

            # History
            kwargs_for_hist = deepcopy(step_kwargs)
            kwargs_for_hist.setdefault("method", regridder.method)
            if intermediate_grids and i < len(intermediate_grids):
                name_inter = list(intermediate_grids.keys())[i]
                cf_grid_2d_args = intermediate_grids[name_inter]["cf_grid_2d"]
                new_history = (
                    f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
                    f"regridded with regridder arguments {kwargs_for_hist} to a xesmf"
                    f" cf_grid_2d with arguments {cf_grid_2d_args} - xESMF v{xe.__version__}"
                )
            else:
                new_history = (
                    f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] regridded with arguments {kwargs_for_hist} - xESMF v{xe.__version__}"
                )
            # Newest entry first
            history = f"{new_history}\n{out.attrs['history']}" if "history" in out.attrs else new_history
            out.attrs["history"] = history

    out = out.drop_vars("latitude_longitude", errors="ignore")

    # Attrs
    out.attrs["cat:processing_level"] = to_level
    out.attrs["cat:domain"] = ds_grid.attrs.get("cat:domain")

    return out
@parse_config
def create_mask(
    ds: xr.Dataset | xr.DataArray,
    *,
    variable: str | None = None,
    where_operator: str | None = None,
    where_threshold: float | str | None = None,
    mask_nans: bool = True,
) -> xr.DataArray:
    """
    Create a 0-1 mask based on incoming arguments.

    Parameters
    ----------
    ds : xr.Dataset or xr.DataArray
        Dataset or DataArray to be evaluated. If a time dimension is present, the first time step will be used.
    variable : str, optional
        If using a Dataset, the variable on which to base the mask.
    where_operator : str, optional
        Operator to use for the threshold comparison. One of "<", "<=", "==", "!=", ">=", ">".
        Needs to be used with `where_threshold`.
    where_threshold : float or str, optional
        Threshold value to use for the comparison. A string can be used to reference units, e.g. "10 mm/day".
        Everything after the first whitespace is interpreted as the units, so units
        containing spaces (e.g. "1e-5 kg m-2 s-1") are supported.
        Needs to be used with `where_operator`.
    mask_nans : bool, optional
        Whether to mask NaN values in the mask array. Default is True.

    Returns
    -------
    xr.DataArray
        Mask array.

    Raises
    ------
    ValueError
        If a Dataset is given without `variable`, if only one of `where_operator` and
        `where_threshold` is given, if `where_operator` is not a supported operator,
        or if a string `where_threshold` is not of the form "<value> <units>".
    """
    ops = {
        "<": operator.lt,
        "<=": operator.le,
        "==": operator.eq,
        "!=": operator.ne,
        ">=": operator.ge,
        ">": operator.gt,
    }

    if isinstance(ds, xr.Dataset):
        if variable is None:
            raise ValueError("A variable needs to be specified when passing a Dataset.")
        ds = ds[variable].copy()
    else:
        ds = ds.copy()
    if "time" in ds.dims:
        ds = ds.isel(time=0)

    mask = xr.ones_like(ds)
    mask.attrs = {"long_name": "Mask"}
    mask.name = "mask"

    # Create the mask based on the threshold
    if (where_operator is None) != (where_threshold is None):
        raise ValueError("'where_operator' and 'where_threshold' must be used together.")
    if where_threshold is not None:
        if where_operator not in ops:
            raise ValueError(f"Unknown 'where_operator': {where_operator!r}. It must be one of {list(ops)}.")
        mask.attrs["where_threshold"] = f"{variable} {where_operator} {where_threshold}"
        if isinstance(where_threshold, str):
            # Split only once so that units containing spaces are kept whole
            parts = where_threshold.split(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(
                    f"A string 'where_threshold' must be of the form '<value> <units>' (e.g. '10 mm/day'), got {where_threshold!r}."
                )
            value, units = parts
            ds = convert_units_to(ds, units)
            where_threshold = float(value)

        mask = xr.where(ops[where_operator](ds, where_threshold), mask, 0, keep_attrs=True)

    # Mask NaNs
    if mask_nans:
        mask = xr.where(ds.notnull(), mask, 0, keep_attrs=True)
        mask.attrs["mask_NaNs"] = "True"
    else:
        # The where clause above will mask NaNs, so we need to revert that
        attrs = mask.attrs
        mask = xr.where(ds.isnull(), 1, mask)
        # This xr.where call is made without keep_attrs, so the attributes are restored by hand
        mask.attrs = attrs
        mask.attrs["mask_NaNs"] = "False"

    return mask
def _regridder(
    ds_in: xr.Dataset,
    ds_grid: xr.Dataset,
    *,
    filename: str | os.PathLike | None = None,
    method: str = "bilinear",
    unmapped_to_nan: bool | None = True,
    **kwargs,
) -> Regridder:
    """
    Build an xesmf Regridder with a few default arguments.

    Parameters
    ----------
    ds_in : xr.Dataset
        Incoming grid. The Dataset needs to have lat/lon coordinates.
    ds_grid : xr.Dataset
        Destination grid. The Dataset needs to have lat/lon coordinates.
    filename : str or os.PathLike, optional
        Path to the NetCDF file with weights information.
        The weights are written there only if the file does not exist yet.
    method : str
        Interpolation method, forwarded to xe.Regridder().
    unmapped_to_nan : bool, optional
        Forwarded as-is to xe.Regridder().
    **kwargs
        Other arguments to send xe.Regridder().

    Returns
    -------
    xe.frontend.Regridder
        Regridder object
    """

    def _with_2d_bounds(grid, mapping_name):
        # Bounds are built only for 2D longitudes that have no bounds yet,
        # and only when the grid mapping variable is present in the dataset.
        buildable = (
            grid.cf["longitude"].ndim == 2
            and "longitude" not in grid.cf.bounds
            and mapping_name in grid
        )
        if not buildable:
            return grid
        return grid.assign_coords(**create_bounds_gridmapping(grid, mapping_name))

    # Only the "conservative*" methods trigger the creation of missing bounds.
    if method.startswith("conservative"):
        mapping_src = get_grid_mapping(ds_in)
        mapping_dst = get_grid_mapping(ds_grid)
        ds_in = _with_2d_bounds(ds_in, mapping_src)
        ds_grid = _with_2d_bounds(ds_grid, mapping_dst)

    regridder = xe.Regridder(
        ds_in=ds_in,
        ds_out=ds_grid,
        method=method,
        unmapped_to_nan=unmapped_to_nan,
        **kwargs,
    )

    if filename is None:
        return regridder

    # Never overwrite an existing weights file.
    if not Path(filename).is_file():
        regridder.to_netcdf(filename)
    return regridder
def create_bounds_gridmapping(ds: xr.Dataset, gridmap: str | None = None) -> xr.Dataset:
    """
    Create bounds for rotated pole datasets.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with a grid mapping coordinate.
    gridmap : str, optional
        Name of the grid mapping coordinate. If None, it will be inferred from the dataset.

    Returns
    -------
    xr.Dataset
        Dataset with bounds coordinates for lat and lon.

    Raises
    ------
    ValueError
        If the grid mapping cannot be inferred, or if `gridmap` is not in `ds`.
    """
    if gridmap is None:
        gridmap = get_grid_mapping(ds)
        if gridmap == "":
            raise ValueError("Grid mapping could not be inferred from the dataset.")
    if gridmap not in ds:
        raise ValueError(f"Input `gridmap`={gridmap} is not a coordinate of ds.")

    xname = ds.cf.axes["X"][0]
    yname = ds.cf.axes["Y"][0]

    ds = ds.cf.add_bounds([yname, xname])

    # In "vertices" format then expand to 2D. From (N, 2) to (N+1,) to (N+1, M+1)
    yv_1d = cfxr.bounds_to_vertices(ds[f"{yname}_bounds"], "bounds")
    xv_1d = cfxr.bounds_to_vertices(ds[f"{xname}_bounds"], "bounds")
    yv = yv_1d.expand_dims({f"{xname}_vertices": xv_1d}).transpose(f"{xname}_vertices", f"{yname}_vertices")
    xv = xv_1d.expand_dims({f"{yname}_vertices": yv_1d}).transpose(f"{xname}_vertices", f"{yname}_vertices")

    crs = get_crs(ds[gridmap])
    plate_carree = ccrs.PlateCarree(globe=crs.globe)

    # Project points
    pts = plate_carree.transform_points(crs, xv.values, yv.values)
    lonv = xv.copy(data=pts[..., 0]).rename("lon_vertices")
    latv = yv.copy(data=pts[..., 1]).rename("lat_vertices")

    # Back to CF bounds format. From (N+1, M+1) to (4, N, M)
    lonb = cfxr.vertices_to_bounds(lonv, ("bounds", xname, yname)).rename("lon_bounds")
    latb = cfxr.vertices_to_bounds(latv, ("bounds", xname, yname)).rename("lat_bounds")

    # Create dataset, set coords and attrs
    ds_bnds = xr.merge([lonb, latb]).assign({"lon": ds.lon, "lat": ds.lat, gridmap: ds[gridmap]})
    ds_bnds[yname] = ds[yname]
    ds_bnds[xname] = ds[xname]
    # Drop "bounds" attribute added by cf-xarray on 1D coords (we didn't keep these)
    ds_bnds[xname].attrs.pop("bounds", "")
    ds_bnds[yname].attrs.pop("bounds", "")
    ds_bnds.lat.attrs["bounds"] = "lat_bounds"
    ds_bnds.lon.attrs["bounds"] = "lon_bounds"
    return ds_bnds.transpose(*ds.lon.dims, "bounds")
def _generate_random_string(length: int) -> str:
    """
    Generate a random string of ASCII letters and digits.

    Parameters
    ----------
    length : int
        Number of characters in the output.

    Returns
    -------
    str
        Random string of `length` characters (empty if `length` is 0).
    """
    characters = string.ascii_letters + string.digits
    # A private generator is used instead of reseeding the module-level one from the
    # clock: the caller's global `random` state is left untouched, and calls made
    # within the same clock tick no longer return the same string.
    rng = random.Random()  # noqa: S311
    return "".join(rng.choice(characters) for _ in range(length))