"""Common utilities to be used in many places."""
import fnmatch
import gettext
import json
import logging
import os
import re
import warnings
from collections import defaultdict
from collections.abc import Sequence
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from itertools import chain
from pathlib import Path
from types import ModuleType
import cftime
import flox.xarray
import numpy as np
import pandas as pd
import xarray as xr
import xsdba
from xarray.coding import cftime_offsets as cfoff
from xclim.core import units
from xclim.core.calendar import parse_offset
from xclim.core.options import METADATA_LOCALES
from xclim.core.options import OPTIONS as XC_OPTIONS
from xclim.core.utils import uses_dask
from .config import parse_config
logger = logging.getLogger(__name__)
__all__ = [
"CV",
"add_attr",
"change_units",
"clean_up",
"date_parser",
"ensure_correct_time",
"ensure_new_xrfreq",
"get_cat_attrs",
"maybe_unstack",
"minimum_calendar",
"natural_sort",
"stack_dates",
"stack_drop_nans",
"standardize_periods",
"translate_time_chunk",
"unstack_dates",
"unstack_fill_nan",
"update_attr",
"xclim_convert_units_to",
"xrfreq_to_timedelta",
]
TRANSLATOR = defaultdict(lambda: lambda s: s)
"""Dictionary of translating objects.
Each key is a two letter locale code and values are functions that return the translated message
as compiled in the gettext catalogs. If a language is not defined or a message not translated,
the function will return the raw message.
"""
try:
for loc in (Path(__file__).parent / "data").iterdir():
if loc.is_dir() and len(loc.name) == 2:
TRANSLATOR[loc.name] = gettext.translation("xscen", localedir=loc.parent, languages=[loc.name]).gettext
except FileNotFoundError as err:
raise ImportError("Your xscen installation doesn't have compiled translations. Run `make translate` from the source directory to fix.") from err
[docs]
def update_attr(
ds: xr.Dataset | xr.DataArray,
attr: str,
new: str,
others: Sequence[xr.Dataset | xr.DataArray] | None = None,
**fmt,
) -> xr.Dataset | xr.DataArray:
r"""
Format an attribute referencing itself in a translatable way.
Parameters
----------
ds : Dataset or DataArray
The input object with the attribute to update.
attr : str
Attribute name.
new : str
New attribute as a template string. It may refer to the old version
of the attribute with the "{attr}" field.
others : Sequence of Datasets or DataArrays
Other objects from which we can extract the attribute `attr`.
These can be referenced as "{attrXX}" in `new`, where XX is the based-1 index of the other source in `others`.
If they don't have the `attr` attribute, an empty string is sent to the string formatting.
See notes.
**fmt
Other formatting data.
Returns
-------
Dataset or DataArray
`ds`, but updated with the new version of `attr`, in each of the activated languages.
Notes
-----
This is meant for constructing attributes by extending a previous version
or combining it from different sources. For example, given a `ds` that has `long_name="Variability"`:
>>> update_attr(ds, "long_name", _("Mean of {attr}"))
Will update the "long_name" of `ds` with `long_name="Mean of Variability"`.
The use of `_(...)` allows the detection of this string by the translation manager. The function
will be able to add a translatable version of the string for each activated language, for example adding
a `long_name_fr="Moyenne de Variabilité"` (assuming a `long_name_fr` was present on the initial `ds`).
If the new attribute is an aggregation from multiple sources, these can be passed in `others`.
>>> update_attr(
... ds0,
... "long_name",
... _("Addition of {attr} and {attr1}, divided by {attr2}"),
... others=[ds1, ds2],
... )
Here, `ds0` will have it's `long_name` updated with the passed string, where `attr1` is the `long_name` of `ds1`
and `attr2` the `long_name` of `ds2`. The process will be repeated for each localized `long_name` available on `ds0`.
For example, if `ds0` has a `long_name_fr`, the template string is translated and
filled with the `long_name_fr` attributes of `ds0`, `ds1` and `ds2`.
If the latter don't exist, the english version is used instead.
"""
others = others or []
# .strip(' .') removes trailing and leading whitespaces and dots
if attr in ds.attrs:
others_attrs = {f"attr{i}": dso.attrs.get(attr, "").strip(" .") for i, dso in enumerate(others, 1)}
ds.attrs[attr] = new.format(attr=ds.attrs[attr].strip(" ."), **others_attrs, **fmt)
# All existing locales
for key in fnmatch.filter(ds.attrs.keys(), f"{attr}_??"):
loc = key[-2:]
others_attrs = {f"attr{i}": dso.attrs.get(key, dso.attrs.get(attr, "")).strip(" .") for i, dso in enumerate(others, 1)}
ds.attrs[key] = TRANSLATOR[loc](new).format(attr=ds.attrs[key].strip(" ."), **others_attrs, **fmt)
[docs]
def add_attr(ds: xr.Dataset | xr.DataArray, attr: str, new: str, **fmt):
r"""
Add a formatted translatable attribute to a Dataset, modifying it in place.
Parameters
----------
ds : Dataset or DataArray
The input object to which the attribute will be added.
attr : str
Attribute name.
new : str
New attribute as a template string.
**fmt
Formatting data to fill the template string.
"""
ds.attrs[attr] = new.format(**fmt)
for loc in XC_OPTIONS[METADATA_LOCALES]:
ds.attrs[f"{attr}_{loc}"] = TRANSLATOR[loc](new).format(**fmt)
[docs]
def date_parser( # noqa: C901
date: str | cftime.datetime | pd.Timestamp | datetime | pd.Period,
*,
end_of_period: bool | str = False,
out_dtype: str = "datetime",
strtime_format: str = "%Y-%m-%d",
freq: str = "h",
) -> str | pd.Period | pd.Timestamp:
"""
Return a datetime from a string.
Parameters
----------
date : str, cftime.datetime, pd.Timestamp, datetime.datetime, pd.Period
Date to be converted.
end_of_period : bool or str
If 'YE' or 'ME', the returned date will be the end of the year or month that contains the received date.
If True, the period is inferred from the date's precision, but `date` must be a string,
otherwise nothing is done.
out_dtype : str
Choices are 'datetime', 'period' or 'str'.
strtime_format : str
If out_dtype=='str', this sets the strftime format.
freq : str
If out_dtype=='period', this sets the frequency of the period.
Returns
-------
pd.Timestamp, pd.Period, str
Parsed date.
"""
# Formats, ordered depending on string length
fmts = {
4: ["%Y"],
6: ["%Y%m"],
7: ["%Y-%m"],
8: ["%Y%m%d"],
10: ["%Y%m%d%H", "%Y-%m-%d"],
12: ["%Y%m%d%H%M"],
16: ["%Y-%m-%d %H:%M"],
19: ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"],
}
def _parse_date(date, fmts):
for fmt in fmts:
try:
# `pd.to_datetime` fails with out-of-bounds
s = datetime.strptime(date, fmt)
except ValueError:
pass
else:
match = fmt
break
else:
raise ValueError(f"Can't parse date {date} with formats {fmts}.")
return s, match
fmt = None
# Timestamp can parse a few date formats by default, but not the ones without spaces
# So we try a few known formats first, then a plain call
# Also we need "fmt" to know the precision of the string (if end_of_period is True)
if isinstance(date, str):
try:
date, fmt = _parse_date(date, fmts[len(date)])
except (KeyError, ValueError):
try:
date = pd.Timestamp(date)
except (ValueError, pd._libs.tslibs.parsing.DateParseError):
date = pd.NaT
elif isinstance(date, cftime.datetime):
for n in range(3):
try:
date = pd.Timestamp((date - pd.Timedelta(n, "D")).isoformat())
except ValueError: # We are NOT catching OutOfBoundsDatetime.
pass
else:
break
else:
raise ValueError("Unable to parse cftime date {date}, even when moving back 2 days.")
elif isinstance(date, pd.Period):
# Pandas, you're a mess: Period.to_timestamp() fails for out-of-bounds dates (<1677, > 2242), but not when parsing a string...
date = pd.Timestamp(date.strftime("%Y-%m-%dT%H:%M:%S"))
if not isinstance(date, pd.Timestamp):
date = pd.Timestamp(date)
if isinstance(end_of_period, str) or (end_of_period is True and fmt):
quasiday = (pd.Timedelta(1, "D") - pd.Timedelta(1, "s")).as_unit(date.unit)
if end_of_period in ["Y", "YE"] or "m" not in fmt:
date = pd.tseries.frequencies.to_offset("YE-DEC").rollforward(date) + quasiday
elif end_of_period in ["M", "ME"] or "d" not in fmt:
date = pd.tseries.frequencies.to_offset("ME").rollforward(date) + quasiday
# TODO: Implement subdaily ?
if out_dtype == "str":
return date.strftime(strtime_format)
elif out_dtype == "period":
return date.to_period(freq)
else:
return date
[docs]
def minimum_calendar(*calendars) -> str:
r"""
Return the minimum calendar from a list.
Uses the hierarchy: 360_day < noleap < standard < all_leap, and returns one of those names.
Parameters
----------
*calendars : str
Calendar names, given as successive arguments.
Returns
-------
str
The calendar with the least number of days in a year.
"""
# Unwrap any lists or tuples given in the input, but without destroying strings.
calendars = [[cal] if isinstance(cal, str) else cal for cal in calendars]
calendars = list(chain(*calendars))
# Raise an error if the calendars are not recognized
unknowns = set(calendars).difference(
[
"360_day",
"365_day",
"noleap",
"standard",
"default",
"all_leap",
"366_day",
"gregorian",
"proleptic_gregorian",
]
)
if unknowns:
warnings.warn(
f"These calendars are not recognized: {unknowns}. Results may be incorrect.",
stacklevel=2,
)
if "360_day" in calendars:
out = "360_day"
elif "noleap" in calendars or "365_day" in calendars:
out = "noleap"
elif all(cal in ["all_leap", "366_day"] for cal in calendars):
out = "all_leap"
else:
out = "standard"
return out
[docs]
def translate_time_chunk(chunks: dict, calendar: str, timesize: int) -> dict:
"""
Translate chunk specification for time into a number.
Parameters
----------
chunks : dict
Dictionary specifying the chunk sizes for each dimension. The time dimension can be specified as:
-1 : translates to `timesize`
'Nyear' : translates to N times the number of days in a year of the given calendar.
calendar : str
The calendar type (e.g., 'noleap', '360_day', 'all_leap').
timesize : int
The size of the time dimension.
Returns
-------
dict
The updated chunks dictionary with the time dimension translated to a number.
Notes
-----
-1 translates to `timesize`
'Nyear' translates to N times the number of days in a year of calendar `calendar`.
"""
for k, v in chunks.items():
if isinstance(v, dict):
chunks[k] = translate_time_chunk(v.copy(), calendar, timesize)
elif k == "time" and v is not None:
if isinstance(v, str) and v.endswith("year"):
n = int(chunks["time"].split("year")[0])
nt = n * {
"noleap": 365,
"365_day": 365,
"360_day": 360,
"all_leap": 366,
"366_day": 366,
}.get(calendar, 365.25)
if nt != int(nt):
warnings.warn(
f"The number of days in {chunks['time']} for calendar {calendar} is not an integer. "
f"Chunks will not align perfectly with year ends.",
stacklevel=2,
)
chunks[k] = int(nt)
elif v == -1:
chunks[k] = timesize
return chunks
[docs]
@parse_config
def stack_drop_nans(
ds: xr.Dataset,
mask: xr.DataArray | list[str],
*,
new_dim: str = "loc",
to_file: str | None = None,
) -> xr.Dataset:
"""
Stack dimensions into a single axis and drops indexes where the mask is false.
Parameters
----------
ds : xr.Dataset
A dataset with the same coords as `mask`.
mask : xr.DataArray or list of str
A boolean DataArray with True on the points to keep. The mask will be loaded within this function, but not the dataset.
Alternatively, a list of dimension names to stack. In this case, a mask will be created by loading all data and checking for NaNs.
The latter is not recommended for large datasets.
new_dim : str
The name of the new stacked dim.
to_file : str, optional
A netCDF filename where to write the stacked coords for use in `unstack_fill_nan`.
If given a string with {shape} and {domain}, the formatting will fill them with
the original shape of the dataset and the global attributes 'cat:domain'.
If None (default), nothing is written to disk.
It is recommended to fill this argument in the config. It will be parsed automatically.
E.g.:
utils:
stack_drop_nans:
to_file: /some_path/coords/coords_{domain}_{shape}.nc
unstack_fill_nan:
coords: /some_path/coords/coords_{domain}_{shape}.nc
Returns
-------
xr.Dataset
Same as `ds`, but all dimensions of mask have been stacked to a single `new_dim`.
Indexes where mask is False have been dropped.
See Also
--------
unstack_fill_nan : The inverse operation.
"""
if isinstance(mask, xr.DataArray):
mask_1d = mask.stack({new_dim: mask.dims})
out = ds.stack({new_dim: mask.dims}).where(mask_1d, drop=True)
else:
mask = ds.coords.to_dataset().drop_vars([v for v in ds.coords if not any(d in mask for d in ds[v].dims)])
mask = xr.DataArray(
np.ones(list(mask.sizes.values())), dims=mask.dims, coords=mask.coords
) # Make it a DataArray to fit the rest of the function
out = ds.stack({new_dim: mask.dims}).dropna(new_dim, how="all")
out = out.reset_index(new_dim)
for dim in mask.dims:
out[dim].attrs.update(ds[dim].attrs)
original_shape = "x".join(map(str, mask.shape))
if to_file is not None:
# Set default path to store the information necessary to unstack
# The name includes the domain and the original shape to uniquely identify the dataset
domain = ds.attrs.get("cat:domain", "unknown")
to_file = to_file.format(domain=domain, shape=original_shape)
if not Path(to_file).parent.exists():
Path(to_file).parent.mkdir(exist_ok=True, parents=True)
# Add all coordinates that might have been affected by the stack
mask = mask.assign_coords({c: ds[c] for c in ds.coords if any(d in mask.dims for d in ds[c].dims)})
mask.coords.to_dataset().to_netcdf(to_file)
# Carry information about original shape to be able to unstack properly
for dim in mask.dims:
out[dim].attrs["original_shape"] = original_shape
# this is needed to fix a bug in xarray '2022.6.0'
out[dim] = xr.DataArray(
out[dim].values,
dims=out[dim].dims,
coords=out[dim].coords,
attrs=out[dim].attrs,
)
return out
[docs]
@parse_config
def unstack_fill_nan(
ds: xr.Dataset,
*,
dim: str = "loc",
coords: None | (str | os.PathLike | Sequence[str | os.PathLike] | dict[str, xr.DataArray]) = None,
):
"""
Unstack a Dataset that was stacked by :py:func:`stack_drop_nans`.
Parameters
----------
ds : xr.Dataset
A dataset with some dimensions stacked by `stack_drop_nans`.
dim : str
The dimension to unstack, same as `new_dim` in `stack_drop_nans`.
coords : str or os.PathLike or Sequence or dict, optional
Additional information used to reconstruct coordinates that might have been lost in the stacking (e.g., if a lat/lon grid was all NaNs).
If a string or os.PathLike : Path to a dataset containing only those coordinates, such as the output of `to_file` in `stack_drop_nans`.
This is the recommended option.
If a dictionary : A mapping from the name of the coordinate that was stacked to a DataArray. Better alternative if no file is available.
If a sequence : The names of the original dimensions that were stacked. Worst option.
If None (default), same as a sequence, but all coordinates that have `dim` as a single dimension are used as the new dimensions.
See Notes for more information.
Returns
-------
xr.Dataset
Same as `ds`, but `dim` has been unstacked to coordinates in `coords`.
Missing elements are filled according to the defaults of `fill_value` of :py:meth:`xarray.Dataset.unstack`.
Notes
-----
Some information might have been completely lost in the stacking process, for example, if a longitude is NaN across all latitudes.
It is impossible to recover that information when using `coords` as a list, which is why it is recommended to use a file or a dictionary instead.
If a dictionary is used, the keys must be the names of the coordinates that were stacked and the values must be the DataArrays.
This method can recover both dimensions and additional coordinates that were not dimensions in the original dataset, but were stacked.
If the original stacking was done with `stack_drop_nans` and the `to_file` argument was used, the `coords` argument should be a string with
the path to the file. Additionally, the file name can contain the formatting fields {shape} and {domain}, which will be automatically filled
with the original shape of the dataset and the global attribute 'cat:domain'. If using that dynamic path, it is recommended to fill the
argument in the xscen config.
E.g.:
utils:
stack_drop_nans:
to_file: /some_path/coords/coords_{domain}_{shape}.nc
unstack_fill_nan:
coords: /some_path/coords/coords_{domain}_{shape}.nc
"""
if coords is None:
logger.info("Dataset unstacked using no coords argument.")
coords = [d for d in ds.coords if ds[d].dims == (dim,)]
if isinstance(coords, str | os.PathLike):
# find original shape in the attrs of one of the dimension
original_shape = "unknown"
for c in ds.coords:
if "original_shape" in ds[c].attrs:
original_shape = ds[c].attrs["original_shape"]
domain = ds.attrs.get("cat:domain", "unknown")
coords = coords.format(domain=domain, shape=original_shape)
msg = f"Dataset unstacked using {coords}."
logger.info(msg)
coords = xr.open_dataset(coords)
# separate coords that are dims or not
coords_and_dims = {name: x for name, x in coords.coords.items() if name in coords.dims}
coords_not_dims = {name: x for name, x in coords.coords.items() if name not in coords.dims}
dims, crds = zip(
*[(name, crd.load().values) for name, crd in ds.coords.items() if crd.dims == (dim,) and name in coords_and_dims], strict=False
)
mindex_obj = pd.MultiIndex.from_arrays(crds, names=dims)
mindex_coords = xr.Coordinates.from_pandas_multiindex(mindex_obj, dim)
out = ds.drop_vars(dims).assign_coords(mindex_coords).unstack(dim)
# only reindex with the dims
out = out.reindex(**coords_and_dims)
# add back the coords that aren't dims
for c in coords_not_dims:
out[c] = coords[c]
else:
coord_not_dim = {}
# Special case where the dictionary contains both dimensions and other coordinates
if isinstance(coords, dict):
coord_not_dim = {k: v for k, v in coords.items() if len(set(v.dims).intersection(list(coords))) != 1}
coords = deepcopy(coords)
coords = {k: v for k, v in coords.items() if k in set(coords).difference(coord_not_dim)}
dims, crds = zip(
*[(name, crd.load().values) for name, crd in ds.coords.items() if (crd.dims == (dim,) and name in set(coords))], strict=False
)
# Reconstruct the dimensions
mindex_obj = pd.MultiIndex.from_arrays(crds, names=dims)
mindex_coords = xr.Coordinates.from_pandas_multiindex(mindex_obj, dim)
out = ds.drop_vars(dims).assign_coords(mindex_coords).unstack(dim)
if isinstance(coords, dict):
# Reindex with the coords that were dimensions
out = out.reindex(**coords)
# Add back the coordinates that aren't dimensions
for c in coord_not_dim:
out[c] = coord_not_dim[c]
# Reorder the dimensions to match the CF conventions
order = [out.cf.axes.get(d, [""])[0] for d in ["T", "Z", "Y", "X"]]
order = [d for d in order if d] + [d for d in out.dims if d not in order]
out = out.transpose(*order)
for dim in dims:
out[dim].attrs.update(ds[dim].attrs)
return out
[docs]
def natural_sort(_list: list[str]) -> list[str]:
"""
For strings of numbers. alternative to sorted() that detects a more natural order.
e.g. [r3i1p1, r1i1p1, r10i1p1] is sorted as [r1i1p1, r3i1p1, r10i1p1] instead of [r10i1p1, r1i1p1, r3i1p1]
Parameters
----------
_list : list of str
The list to sort.
Returns
-------
list[str]
The sorted list.
"""
convert = lambda text: int(text) if text.isdigit() else text.lower() # noqa: E731
alphanum_key = lambda key: [ # noqa: E731
convert(c) for c in re.split("([0-9]+)", key)
]
return sorted(_list, key=alphanum_key)
[docs]
def get_cat_attrs(ds: xr.Dataset | xr.DataArray | dict, prefix: str = "cat:", var_as_str=False) -> dict:
"""
Return the catalog-specific attributes from a dataset or dictionary.
Parameters
----------
ds : xr.Dataset, dict
Dataset to be parsed. If a dictionary, it is assumed to be the attributes of the dataset (ds.attrs).
prefix : str
Prefix automatically generated by intake-esm. With xscen, this should be 'cat:'.
var_as_str : bool
If True, 'variable' will be returned as a string if there is only one.
Returns
-------
dict
Compilation of all attributes in a dictionary.
"""
if isinstance(ds, xr.Dataset | xr.DataArray):
attrs = ds.attrs
else:
attrs = ds
facets = {k[len(prefix) :]: v for k, v in attrs.items() if k.startswith(f"{prefix}")}
# to be usable in a path
if var_as_str and "variable" in facets and not isinstance(facets["variable"], str) and len(facets["variable"]) == 1:
facets["variable"] = facets["variable"][0]
return facets
def strip_cat_attrs(ds: xr.Dataset, prefix: str = "cat:") -> xr.Dataset:
"""
Remove attributes added from the catalog by `to_dataset` or `extract_dataset`.
Parameters
----------
ds : xr.Dataset
Dataset to be parsed.
prefix : str
Prefix automatically generated by intake-esm. With xscen, this should be 'cat:'.
Returns
-------
xr.Dataset
Dataset with all attributes starting with `prefix` removed.
"""
dsc = ds.copy()
for k in list(dsc.attrs):
if k.startswith(prefix):
del dsc.attrs[k]
return dsc
[docs]
@parse_config
def maybe_unstack(
ds: xr.Dataset,
dim: str | None = "loc",
coords: str | None = None,
rechunk: dict | None = None,
stack_drop_nans: bool = False,
) -> xr.Dataset:
"""
If stack_drop_nans is True, unstack and rechunk.
Parameters
----------
ds : xr.Dataset
Dataset to unstack.
dim : str, optional
Dimension to unstack.
coords : str, optional
Path to a dataset containing the coords to unstack (and only those).
rechunk : dict, optional
If not None, rechunk the dataset after unstacking.
stack_drop_nans : bool
If True, unstack the dataset and rechunk it.
If False, do nothing.
Returns
-------
xr.Dataset
Unstacked dataset.
"""
if stack_drop_nans:
ds = unstack_fill_nan(ds, dim=dim, coords=coords)
if rechunk is not None:
ds = ds.chunk(rechunk)
return ds
# Read CVs and fill a virtual module
CV = ModuleType(
"CV",
(
"""
Mappings of (controlled) vocabulary. This module is generated automatically
from json files in xscen/CVs. Functions are essentially mappings, most of
which are meant to provide translations between columns.\n\n
Json files must be shallow dictionaries to be supported. If the json file
contains a ``is_regex: True`` entry, then the keys are automatically
translated as regex patterns and the function returns the value of the first
key that matches the pattern. Otherwise the function essentially acts like a
normal dictionary. The 'raw' data parsed from the json file is added in the
``dict`` attribute of the function.
Example:
.. code-block:: python
xs.utils.CV.frequency_to_timedelta.dict
.. literalinclude:: ../src/xscen/CVs/frequency_to_timedelta.json
:language: json
:caption: frequency_to_timedelta
.. literalinclude:: ../src/xscen/CVs/frequency_to_xrfreq.json
:language: json
:caption: frequency_to_xrfreq
.. literalinclude:: ../src/xscen/CVs/infer_resolution.json
:language: json
:caption: infer_resolution
.. literalinclude:: ../src/xscen/CVs/resampling_methods.json
:language: json
:caption: resampling_methods
.. literalinclude:: ../src/xscen/CVs/variable_names.json
:language: json
:caption: variable_names
.. literalinclude:: ../src/xscen/CVs/xrfreq_to_frequency.json
:language: json
:caption: xrfreq_to_frequency
.. literalinclude:: ../src/xscen/CVs/xrfreq_to_timedelta.json
:language: json
:caption: xrfreq_to_timedelta
"""
),
)
def __read_CVs(cvfile): # noqa: N802
with cvfile.open("r") as f:
cv = json.load(f)
is_regex = cv.pop("is_regex", False)
doc = """Controlled vocabulary mapping from {name}.
The raw dictionary can be accessed by the dict attribute of this function.
Parameters
----------
key : str
The value to translate.{regex}
default : 'pass', 'error' or Any
If the key is not found in the mapping, default controls the behaviour.
- "error", a KeyError is raised (default).
- "pass", the key is returned.
- another value, that value is returned.
"""
def cvfunc(key, default="error"):
if is_regex:
for cin, cout in cv.items():
try:
if re.fullmatch(cin, key):
return cout
except TypeError:
pass
else:
if key in cv:
return cv[key]
if isinstance(default, str):
if default == "pass":
return key
if default == "error":
raise KeyError(key)
return default
cvfunc.__name__ = cvfile.stem
cvfunc.__doc__ = doc.format(
name=cvfile.stem.replace("_", " "),
regex=" The key will be matched using regex" if is_regex else "",
)
cvfunc.__dict__["dict"] = cv
cvfunc.__module__ = "xscen.CV"
return cvfunc
for cvfile in Path(__file__).parent.joinpath("CVs").glob("*.json"):
try:
CV.__dict__[cvfile.stem] = __read_CVs(cvfile)
# FIXME: This is a catch-all, but we should be more specific
except Exception as err: # noqa: BLE001
raise ValueError(f"Unable to process CV file: {cvfile}.") from err
[docs]
def change_units(ds: xr.Dataset, variables_and_units: dict) -> xr.Dataset:
"""
Change units of Datasets to non-CF units.
Parameters
----------
ds : xr.Dataset
Dataset to use.
variables_and_units : dict
Description of the variables and units to output.
Returns
-------
xr.Dataset
The dataset with updated units.
See Also
--------
xclim.core.units.convert_units_to : Convert units.
xclim.core.units.rate2amount : Convert a rate to an amount.
"""
with xr.set_options(keep_attrs=True):
for v in variables_and_units:
if v in ds:
if units.units2pint(ds[v]) != units.units2pint(variables_and_units[v]):
time_in_ds = units.units2pint(ds[v]).dimensionality.get("[time]")
time_in_out = units.units2pint(variables_and_units[v]).dimensionality.get("[time]")
if time_in_ds == time_in_out:
ds = ds.assign({v: units.convert_units_to(ds[v], variables_and_units[v])})
elif time_in_ds - time_in_out == 1:
# ds is an amount
ds = ds.assign({v: units.amount2rate(ds[v], out_units=variables_and_units[v])})
elif time_in_ds - time_in_out == -1:
# ds is a rate
ds = ds.assign({v: units.rate2amount(ds[v], out_units=variables_and_units[v])})
else:
raise ValueError(
f"No known transformation between {ds[v].units} and {variables_and_units[v]} (temporal dimensionality mismatch)."
)
# update unit name if physical units are equal but not their name (ex. degC vs °C)
if (units.units2pint(ds[v]) == units.units2pint(variables_and_units[v])) and (ds[v].units != variables_and_units[v]):
ds = ds.assign({v: ds[v].assign_attrs(units=variables_and_units[v])})
return ds
def _convert_units_to_infer(source, target):
return units.convert_units_to(source, target, context="infer")
[docs]
@contextmanager
def xclim_convert_units_to():
"""
Patch xsdba with xclim's units converter.
Yields
------
None
In this context, ``xsdba.units.convert_units_to`` is replaced with
``xclim.core.units.convert_units_to`` with `context="infer"` activated.
"""
original_function = xsdba.units._convert_units_to
new_function = _convert_units_to_infer
try:
xsdba.units._convert_units_to = new_function
yield
finally:
xsdba.units._convert_units_to = original_function
[docs]
def clean_up( # noqa: C901
ds: xr.Dataset,
*,
variables_and_units: dict | None = None,
fill_nan_ds: xr.Dataset | None = None,
convert_calendar_kwargs: dict | None = None,
missing_by_var: dict | None = None,
maybe_unstack_dict: dict | None = None,
round_var: dict | None = None,
clip_var: dict | None = None,
common_attrs_only: None | (dict | list[xr.Dataset | str | os.PathLike]) = None,
common_attrs_open_kwargs: dict | None = None,
attrs_to_remove: dict | None = None,
remove_all_attrs_except: dict | None = None,
add_attrs: dict | None = None,
change_attr_prefix: str | dict | None = None,
to_level: str | None = None,
) -> xr.Dataset:
"""
Clean up of the dataset.
It can:
- convert to the right units using xscen.utils.change_units
- convert the calendar and interpolate over missing dates
- call the xscen.utils.maybe_unstack function
- round variables
- clip variables
- remove a list of attributes
- remove everything but a list of attributes
- add attributes
- change the prefix of the catalog attrs
in that order.
Parameters
----------
ds : xr.Dataset
Input dataset to clean up.
variables_and_units : dict, optional
Dictionary of variable to convert. e.g. {'tasmax': 'degC', 'pr': 'mm d-1'}.
fill_nan_ds : xarray.Dataset, optional
Dataset with the same spatial dimensions as ds and the same cat:domain attrs.
Will fill NaNs in ds with the values from the variables of the same name in this Dataset.
convert_calendar_kwargs : dict, optional
Dictionary of arguments to feed to xarray.Dataset.convert_calendar. This will be the same for all variables.
If missing_by_vars is given, it will override the 'missing' argument given here.
Eg. {'calendar': 'standard', 'align_on': 'random'}.
missing_by_var : dict, optional
Dictionary where the keys are the variables and the values are the argument to feed the `missing`
parameters of xarray.Dataset.convert_calendar for the given variable with the
`convert_calendar_kwargs`. When the value of an entry is 'interpolate', the missing values will be filled
with NaNs, then linearly interpolated over time.
maybe_unstack_dict : dict, optional
Dictionary to pass to xscen.common.maybe_unstack function.
The format should be: {'coords': path_to_coord_file, 'rechunk': {'time': -1 }, 'stack_drop_nans': True}.
round_var : dict, optional
Dictionary where the keys are the variables of the dataset and the values are the number of
decimal places to round to.
clip_var : dict, optional
Dictionary where the keys are the variables of the dataset and the values are the arguments to give ``.clip()``.
common_attrs_only : dict, list of datasets, or list of paths, optional
Dictionary of datasets or list of datasets, or path to NetCDF or Zarr files.
Keeps only the global attributes that are the same for all datasets and generates a new id.
common_attrs_open_kwargs : dict, optional
Dictionary of arguments for xarray.open_dataset(). Used with common_attrs_only if given paths.
attrs_to_remove : dict, optional
Dictionary where the keys are the variables and the values are a list of the attrs that should be removed.
The match is done using re.fullmatch, so the strings can be regex patterns but don't need to contain '^' or '$'.
For global attrs, use the key 'global'.
e.g. {'global': ['unnecessary note', 'cell.*'], 'tasmax': 'old_name'}.
remove_all_attrs_except : dict, optional
Dictionary where the keys are the variables and the values are a list of the attrs that should NOT be removed.
The match is done using re.fullmatch, so the strings can be regex patterns but don't need to contain '^' or '$'.
All other attributes will be deleted. For global attrs, use the key 'global'.
e.g. {'global': ['necessary note', '^cat:'], 'tasmax': 'new_name'}.
add_attrs : dict, optional
Dictionary where the keys are the variables and the values are a another dictionary of attributes.
For global attrs, use the key 'global'.
e.g. {'global': {'title': 'amazing new dataset'}, 'tasmax': {'note': 'important info about tasmax'}}.
change_attr_prefix : str or dict, optional
If a string, replace "cat:" in the catalog global attributes by this new string.
If a dictionary, the key is the old prefix and the value is the new prefix.
to_level : str, optional
The processing level to assign to the output.
Returns
-------
xr.Dataset
Cleaned up dataset.
See Also
--------
xarray.Dataset.convert_calendar : Calendar conversion.
xarray.DataArray.round : Rounding of data array values.
xarray.DataArray.clip : Clipping of data array values between a given min and max.
"""
ds = ds.copy()
if variables_and_units:
msg = f"Converting units: {variables_and_units}"
logger.info(msg)
ds = change_units(ds=ds, variables_and_units=variables_and_units)
if fill_nan_ds is not None:
# check if any non-time dimension are different
if any([ds.sizes[d] != fill_nan_ds.sizes[d] for d in ds.dims if d != "time"]) or (
ds.attrs.get("cat:domain", "foo") != fill_nan_ds.attrs.get("cat:domain", "foo")
):
raise ValueError(
"The non-time dimensions or the cat:domain attribute of the simulation"
" and reference datasets do not match. "
"Cannot fill missing values."
)
for var in ds.data_vars:
if var in fill_nan_ds:
ds[var] = ds[var].combine_first(fill_nan_ds[var])
new_history = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Filled missing values using {fill_nan_ds.get('cat:id', '')} dataset."
history = f"{new_history}\n{ds[var].attrs['history']}" if "history" in ds[var].attrs else new_history
ds[var].attrs["history"] = history
# convert calendar
if convert_calendar_kwargs:
vars_with_no_time = [v for v in ds.data_vars if "time" not in ds[v].dims]
# create mask of grid point that should always be nan
ocean = ds.isnull().all("time")
# if missing_by_var exist make sure missing data are added to time axis
if missing_by_var:
if not all(k in missing_by_var.keys() for k in ds.data_vars):
raise ValueError("All variables must be in 'missing_by_var' if using this option.")
convert_calendar_kwargs["missing"] = -9999
# make default `align_on`='`random` when the initial calendar is 360day
if any(cal == "360_day" for cal in [ds.time.dt.calendar, convert_calendar_kwargs["calendar"]]) and "align_on" not in convert_calendar_kwargs:
convert_calendar_kwargs["align_on"] = "random"
msg = f"Converting calendar with {convert_calendar_kwargs}."
logger.info(msg)
ds = ds.convert_calendar(**convert_calendar_kwargs).where(~ocean)
# FIXME: Fix for xarray <= 2025.04.0: https://github.com/pydata/xarray/issues/10266
for vv in vars_with_no_time:
if "time" in ds[vv].dims:
ds[vv] = ds[vv].isel(time=0).drop_vars("time")
# convert each variable individually
if missing_by_var:
# remove 'missing' argument to be replaced by `missing_by_var`
del convert_calendar_kwargs["missing"]
for var, missing in missing_by_var.items():
msg = f"Filling missing {var} with {missing}"
logging.info(msg)
if missing == "interpolate":
ds_with_nan = ds[var].where(ds[var] != -9999)
converted_var = ds_with_nan.chunk({"time": -1}).interpolate_na("time", method="linear")
else:
converted_var = ds[var].where(ds[var] != -9999, other=missing)
ds = ds.assign({var: converted_var})
# unstack nans
if maybe_unstack_dict:
ds = maybe_unstack(ds, **maybe_unstack_dict)
if round_var:
for var, n in round_var.items():
ds[var] = ds[var].round(n)
new_history = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Rounded '{var}' to {n} decimals."
history = f"{new_history}\n{ds[var].attrs['history']}" if "history" in ds[var].attrs else new_history
ds[var].attrs["history"] = history
if clip_var:
for var, c in clip_var.items():
ds[var] = ds[var].clip(*c, keep_attrs=True)
new_history = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Clipped '{var}' to {c}."
history = f"{new_history}\n{ds[var].attrs['history']}" if "history" in ds[var].attrs else new_history
ds[var].attrs["history"] = history
if common_attrs_only:
from .catalog import generate_id
common_attrs_open_kwargs = common_attrs_open_kwargs or {}
if isinstance(common_attrs_only, dict):
common_attrs_only = list(common_attrs_only.values())
for i in range(len(common_attrs_only)):
if isinstance(common_attrs_only[i], str | os.PathLike):
dataset = xr.open_dataset(common_attrs_only[i], **common_attrs_open_kwargs)
else:
dataset = common_attrs_only[i]
attributes = ds.attrs.copy()
for a_key, a_val in attributes.items():
if (a_key not in dataset.attrs) or (a_key in ["cat:date_start", "cat:date_end"]) or (a_val != dataset.attrs[a_key]):
del ds.attrs[a_key]
# generate a new id
try:
ds.attrs["cat:id"] = generate_id(ds).iloc[0]
except IndexError as err:
msg = f"Unable to generate a new id for the dataset. Got {err}."
logger.warning(msg)
if to_level:
ds.attrs["cat:processing_level"] = to_level
# remove attrs
if attrs_to_remove:
for var, list_of_attrs in attrs_to_remove.items():
obj = ds if var == "global" else ds[var]
to_remove = list(chain.from_iterable([list(filter(re.compile(attr).fullmatch, list(obj.attrs.keys()))) for attr in list_of_attrs]))
for attr in to_remove:
del obj.attrs[attr]
# delete all attrs, but the ones in the list
if remove_all_attrs_except:
for var, list_of_attrs in remove_all_attrs_except.items():
obj = ds if var == "global" else ds[var]
to_keep = list(chain.from_iterable([list(filter(re.compile(attr).fullmatch, list(obj.attrs.keys()))) for attr in list_of_attrs]))
to_remove = list(set(obj.attrs.keys()).difference(to_keep))
for attr in to_remove:
del obj.attrs[attr]
if add_attrs:
for var, attrs in add_attrs.items():
obj = ds if var == "global" else ds[var]
for attrname, attrtmpl in attrs.items():
obj.attrs[attrname] = attrtmpl
if change_attr_prefix:
if isinstance(change_attr_prefix, str):
change_attr_prefix = {"cat:": change_attr_prefix}
# Make sure that the prefixes are in the right format
chg_attr_prefix = {}
for old_prefix, new_prefix in change_attr_prefix.items():
if not old_prefix.endswith(":"):
old_prefix += ":"
if not new_prefix.endswith(":"):
new_prefix += ":"
chg_attr_prefix[old_prefix] = new_prefix
# Change the prefixes, but keep the order of the keys
attrs = {}
for ds_attr in list(ds.attrs.keys()):
changed = False
for old_prefix, new_prefix in chg_attr_prefix.items():
if ds_attr.startswith(old_prefix):
new_name = ds_attr.replace(old_prefix, new_prefix)
attrs[new_name] = ds.attrs[ds_attr]
changed = True
if not changed:
attrs[ds_attr] = ds.attrs[ds_attr]
ds.attrs = attrs
return ds
def _unstack_doy(ds: xr.Dataset, new_dim: str | None):
"""Unstack a daily timeseries into dayofyear and year."""
ds = (
ds.assign_coords(original_time=ds.time)
.assign_coords(
xr.Coordinates.from_pandas_multiindex(
pd.MultiIndex.from_arrays(
(ds.time.dt.year.values, ds.time.dt.dayofyear.values),
names=("year", new_dim or "dayofyear"),
),
"time",
)
)
.unstack("time")
)
return ds.rename(year="time").assign_coords(time=pd.to_datetime({"year": ds.year, "month": 1, "day": 1}))
[docs]
def unstack_dates( # noqa: C901
ds: xr.Dataset,
seasons: dict[int, str] | None = None,
new_dim: str | None = None,
winter_starts_year: bool = False,
year_start_month: int = 1,
):
"""
Unstack a multi-season timeseries into a yearly axis and a season one.
Parameters
----------
ds : xr.Dataset or DataArray
The xarray object with a "time" coordinate.
Only supports daily or coarser frequencies (excluding weekly).
The time axis must be complete and regular (`xr.infer_freq(ds.time)` doesn't fail).
seasons : dict, optional
A dictionary from month number (as int) to a season name.
If not given, it is guessed from the time coordinate frequency.
Not used with daily data. See notes.
new_dim : str, optional
The name of the new dimension.
If None, the name is inferred from the frequency of the time axis.
See notes.
winter_starts_year : bool, optional
Deprecated. Setting to True is the old way of passing `year_start_month=12`.
year_start_month : int, optional
Change on which month the year starts. If greater than 1, seasons/months between that and 12 (included)
are associated with the following year. Usually this is used as 12 with seasonal data, so that the DJF season is
associated with the next year, i.e. DJF made from [Dec 1980, Jan 1981, and Feb 1981] will be associated with the year 1981, not 1980.
Not used with daily data. Replaces `winter_starts_year`.
Returns
-------
xr.Dataset or DataArray
Same as ds but the time axis is now yearly (YS-JAN) and the seasons are along the new dimension.
The previous time dimension is left as the new 2D `original_time`.
Notes
-----
When `seasons` is None, the inferred frequency determines the new coordinate:
- For MS, the coordinates are the month abbreviations in english (JAN, FEB, etc.)
- For ?QS-? and other ?MS frequencies, the coordinates are the initials of the months in each season. Ex: QS -> DJF, MAM, JJA, SON.
- For YS or YS-JAN, the new coordinate has a single value of "annual".
- For ?YS-? frequencies, the new coordinate has a single value of "annual-{anchor}". Ex: YS-JUL -> "annual-JUL".
When `new_dim` is None, the new dimension name is inferred from the frequency:
- For ?YS, ?QS frequencies or ?MS with mult > 1, the new dimension is "season".
- For MS, the new dimension is "month".
"""
if winter_starts_year:
warnings.warn(
"Since xscen 0.14, `winter_starts_year=True` has been deprecated in favor of `year_start_month=12`."
" It will be removed in a future release.",
FutureWarning,
stacklevel=1,
)
year_start_month = 12
# Get some info about the time axis
freq = xr.infer_freq(ds.time)
if freq is None:
raise ValueError(
"The data must have a clean time coordinate. If you know the "
"data's frequency, please pass `ds.resample(time=freq).first()` "
"to pad missing dates and reset the time coordinate."
)
first, last = ds.indexes["time"][[0, -1]]
use_cftime = xr.coding.times.contains_cftime_datetimes(ds.time.variable)
calendar = ds.time.dt.calendar
mult, base, isstart, anchor = parse_offset(freq)
if base == "D": # fast-track for daily
return _unstack_doy(ds, new_dim)
if base not in "YAQM":
raise ValueError(f"Only daily frequencies or coarser are supported. Got: {freq}.")
if new_dim is None:
if base == "M" and mult == 1:
new_dim = "month"
else:
new_dim = "season"
if base in "YA":
if seasons:
seaname = f"{seasons[first.month]}"
elif anchor == "JAN":
seaname = "annual"
else:
seaname = f"annual-{anchor}"
if mult > 1:
seaname = f"{mult}{seaname}"
# Fast track for annual, if nothing more needs to be done.
if year_start_month == 1:
dso = ds.expand_dims({new_dim: [seaname]}).assign_coords(original_time=ds.time.expand_dims({new_dim: [seaname]}))
dso["time"] = xr.date_range(
f"{first.year}-01-01",
f"{last.year}-01-01",
freq=f"{mult}YS",
calendar=calendar,
use_cftime=use_cftime,
)
return dso
else:
seasons = seasons or {}
seasons.update({first.month: seaname})
if base == "M" and 12 % mult != 0:
raise ValueError(f"Only periods that divide the year evenly are supported. Got {freq}.")
# Guess the new season coordinate
if seasons is None:
if base == "Q" or (base == "M" and mult > 1):
# Labels are the month initials
months = np.array(list("JFMAMJJASOND"))
n = mult * {"M": 1, "Q": 3}[base]
seasons = {m: "".join(months[np.array(range(m - 1, m + n - 1)) % 12]) for m in np.unique(ds.time.dt.month)}
else: # M or MS
seasons = xr.coding.cftime_offsets._MONTH_ABBREVIATIONS
else:
# Only keep the entries for the months in the data
seasons = {m: seasons[m] for m in np.unique(ds.time.dt.month)}
if year_start_month > 1:
# Sort season names from the beginning of the year
seas_list = [seasons[m] for m in sorted(seasons.keys()) if m >= year_start_month] + [
seasons[m] for m in sorted(seasons.keys()) if m < year_start_month
]
# The year associated with each timestamp
years = ds.time.dt.year + xr.where(ds.time.dt.month >= year_start_month, 1, 0)
else:
# The ordered season names
seas_list = [seasons[month] for month in sorted(seasons.keys())]
years = ds.time.dt.year
# The goal here is to use `reshape()` instead of `unstack` to limit the number of dask operations.
# Thus, the time axis must be properly constructed so that reshapes fits the final size.
# We pad on both sides to ensure full years
pad_left = seas_list.index(seasons[first.month])
pad_right = len(seas_list) - (seas_list.index(seasons[last.month]) + 1)
dsp = ds.pad(time=(pad_left, pad_right)) # pad with NaN
# Similarly pad our "group labels".
years = years.pad(time=(pad_left, pad_right), constant_values=(years[0], years[-1]))
# And pad the original time
# a bit more complicated, we use the negative freq feature of date_range to get valid dates before start
_before = xr.date_range(
ds.indexes["time"][0],
freq=f"-1{freq}" if mult == 1 else f"-{freq}",
periods=pad_left + 1,
inclusive="right",
calendar=ds.time.dt.calendar,
use_cftime=ds.time.dtype == "O",
)[::-1]
_after = xr.date_range(
ds.indexes["time"][-1], freq=freq, periods=pad_right + 1, inclusive="right", calendar=ds.time.dt.calendar, use_cftime=ds.time.dtype == "O"
)
dsp = dsp.assign_coords(time=_before.append(ds.indexes["time"]).append(_after))
# New coords
new_time = xr.date_range( # New time axis (YS)
f"{years[0].item()}-01-01",
f"{years[-1].item()}-01-01",
freq="YS",
calendar=calendar,
use_cftime=use_cftime,
)
def _reshape_da(da):
# Replace (A,'time',B) by (A,'time', 'season',B) in both the new shape and the new dims
new_dims = list(chain.from_iterable([d] if d != "time" else ["time", new_dim] for d in da.dims))
new_shape = [len(new_coords[d]) for d in new_dims]
# Use dask or numpy's algo.
if uses_dask(da):
# This is where it happens. Flox will minimally rechunk
# so the reshape operation can be performed blockwise
da = flox.xarray.rechunk_for_blockwise(da, "time", years)
return xr.DataArray(da.data.reshape(new_shape), dims=new_dims)
new_coords = dict(ds.coords)
new_coords.update({"time": new_time, new_dim: seas_list})
# put other coordinates that depend on time in the new shape
for coord in new_coords:
if (coord not in ["time", new_dim]) and ("time" in ds[coord].dims):
new_coords[coord] = _reshape_da(dsp[coord])
new_coords["original_time"] = _reshape_da(dsp.time)
if isinstance(ds, xr.Dataset):
dso = dsp.map(_reshape_da, keep_attrs=True)
else:
dso = _reshape_da(dsp)
return dso.assign_coords(**new_coords)
[docs]
def stack_dates(ds: xr.Dataset) -> xr.Dataset:
"""
Revert the effect of :py:func:`unstack_dates`.
Parameters
----------
ds : xr.Dataset
The dataset to unstack. Must contain the `original_time` coordinate added by `unstack_dates`.
Returns
-------
xr.Dataset
The dataset with the time axis restored to its original form.
"""
time_dims = list(ds.original_time.dims)
season_dim = list(set(time_dims) - {"time"})
out = (
ds.stack(real_time=time_dims)
.swap_dims(real_time="original_time")
.drop_vars(["real_time", *time_dims])
.rename(original_time="time")
.transpose(*[d for d in ds.dims if d not in season_dim])
)
return out.where(out.time.notnull(), drop=True)
[docs]
def ensure_correct_time(ds: xr.Dataset, xrfreq: str) -> xr.Dataset:
"""
Ensure a dataset has the correct time coordinate, as expected for the given frequency.
Parameters
----------
ds : xr.Dataset
The dataset to check and modify if necessary. Must have a "time" coordinate.
xrfreq : str
The expected frequency of the time coordinate, in xarray offset alias format (e.g. "D", "MS", etc.).
Returns
-------
xr.Dataset
The dataset with a corrected time coordinate, if necessary.
Notes
-----
Daily or finer datasets are "floored" even if `xr.infer_freq` succeeds.
Errors are raised if the number of data points per period is not 1.
The dataset is modified in-place, but returned nonetheless.
"""
# Check if we got the expected freq (skip for too short timeseries)
inffreq = xr.infer_freq(ds.time) if ds.time.size > 2 else None
if inffreq == xrfreq:
# Even when the freq is correct, we ensure the correct "anchor" for daily and finer
# also done in preprocess
if xrfreq in "DHTMUL":
ds["time"] = ds.time.dt.floor(xrfreq)
else:
# We can't infer it, there might be a problem
counts = ds.time.resample(time=xrfreq).count()
if (counts > 1).any().item():
raise ValueError(f"Dataset is labelled as having a sampling frequency of {xrfreq}, but some periods have more than one data point.")
if (counts.isnull() | (counts == 0)).any().item():
raise ValueError("The resampling count contains NaNs or 0s. There might be some missing data.")
ds["time"] = counts.time
return ds
[docs]
def standardize_periods(
periods: list[str | pd.Timestamp] | list[list[str | pd.Timestamp]] | None,
multiple: bool = True,
end_of_periods: bool = True,
out_dtype: str = "str",
) -> list[str] | list[list[str]] | None:
"""
Reformat the input to a list of strings or Timestamps, ['start', 'end'], or a list of such lists. Does not modify in-place.
Parameters
----------
periods : list of str or pd.Timestamp, or list of lists of str or pd.Timestamp, optional
The period(s) to standardize. If None, return None.
multiple : bool
If True, return a list of periods, otherwise return a single period.
end_of_periods : bool or str
If 'YE' or 'ME', the returned date will be the end of the year or month that contains the received date.
If True (default), standardizes yearly and monthly periods to end on the last second of the last day of the year/month.
This parameter is only used for str periods that do not specify the month/day.
out_dtype : str
Choices are 'datetime', 'period' or 'str'. Defaults to 'str', which will only output the year.
Returns
-------
list of str or list of lists of str, or None
The standardized period(s).
"""
if periods is None:
return periods
periods = deepcopy(periods)
if not isinstance(periods[0], list):
periods = [periods]
for i in range(len(periods)):
if len(periods[i]) != 2:
raise ValueError("Each instance of 'periods' should be comprised of two elements: [start, end].")
period = periods[i]
if isinstance(period[0], int) or isinstance(period[0], str):
period[0] = date_parser(str(period[0]), out_dtype="datetime")
if isinstance(period[1], int) or isinstance(period[1], str):
period[1] = date_parser(str(period[1]), out_dtype="datetime", end_of_period=end_of_periods)
if period[0] > period[1]:
raise ValueError(f"'periods' should be in chronological order, received {periods[i]}.")
# TODO: allow more than year in periods for out_dtype = str
periods[i] = [
date_parser(period[0], out_dtype=out_dtype, strtime_format="%Y"),
date_parser(period[1], out_dtype=out_dtype, strtime_format="%Y"),
]
if multiple:
return periods
else:
if len(periods) > 1:
raise ValueError(f"'period' should be a single instance of [start, end], received {len(periods)}.")
return periods[0]
def season_sort_key(idx: pd.Index, name: str | None = None) -> pd.Index:
"""
Get a proper sort key for a "season" or "month" index to avoid alphabetical sorting.
If any of the values in the index is not recognized as a 3-letter
season code or a 3-letter month abbreviation, the operation is
aborted and the index is returned untouched.
DJF is the first season of the year.
Parameters
----------
idx : pd.Index
Any array that implements a `map` method.
If name is "month", index elements are expected to be 3-letter month abbreviations, uppercase (JAN, FEB, etc).
If name is "season", index elements are expected to be 3-letter season abbreviations, uppercase (DJF, AMJ, OND, etc.)
If anything else, the index is returned untouched.
name : str, optional
The index name. By default, the `name` attribute of the index is used, if present.
Returns
-------
pd.Index
Integer sort key for months and seasons, the input index untouched otherwise.
"""
try:
if (name or getattr(idx, "name", None)) == "season":
m = "DJFMAMJJASONDJ"
return idx.map(m.index)
if (name or getattr(idx, "name", None)) == "month":
m = list(xr.coding.cftime_offsets._MONTH_ABBREVIATIONS.values())
return idx.map(m.index)
except (TypeError, ValueError) as err:
# ValueError if string not in seasons, or value not in months
# TypeError if season element was not a string.
logging.error(err)
return idx
[docs]
def xrfreq_to_timedelta(freq: str) -> pd.Timedelta:
"""
Approximate the length of a period based on its frequency offset.
Parameters
----------
freq : str
Frequency string.
Returns
-------
pd.Timedelta
Approximate length of the period.
"""
N, B, _, _ = parse_offset(freq)
return N * pd.Timedelta(CV.xrfreq_to_timedelta(B, "NaT"))
[docs]
def ensure_new_xrfreq(freq: str) -> str: # noqa: C901
"""
Convert the frequency string to the newer syntax (pandas >= 2.2), if needed.
Parameters
----------
freq : str
Frequency string.
Returns
-------
str
New format frequency string.
"""
# Copied from xarray xr.coding.cftime_offsets._legacy_to_new_freq
# https://github.com/pydata/xarray/pull/8627/files
if not isinstance(freq, str):
# For when freq is NaN or None in a catalog
return freq
try:
freq_as_offset = cfoff.to_offset(freq, warn=False)
except ValueError:
# freq may be valid in pandas but not in xarray
return freq
if isinstance(freq_as_offset, cfoff.MonthEnd) and "ME" not in freq:
freq = freq.replace("M", "ME")
elif isinstance(freq_as_offset, cfoff.QuarterEnd) and "QE" not in freq:
freq = freq.replace("Q", "QE")
elif isinstance(freq_as_offset, cfoff.YearBegin) and "YS" not in freq:
freq = freq.replace("AS", "YS")
elif isinstance(freq_as_offset, cfoff.YearEnd):
if "A-" in freq:
# Check for and replace "A-" instead of just "A" to prevent
# corrupting anchored offsets that contain "Y" in the month
# abbreviation, e.g. "A-MAY" -> "YE-MAY".
freq = freq.replace("A-", "YE-")
elif "Y-" in freq:
freq = freq.replace("Y-", "YE-")
elif freq.endswith("A"):
# the "A-MAY" case is already handled above
freq = freq.replace("A", "YE")
elif "YE" not in freq and freq.endswith("Y"):
# the "Y-MAY" case is already handled above
freq = freq.replace("Y", "YE")
elif isinstance(freq_as_offset, cfoff.Hour):
freq = freq.replace("H", "h")
elif isinstance(freq_as_offset, cfoff.Minute):
freq = freq.replace("T", "min")
elif isinstance(freq_as_offset, cfoff.Second):
freq = freq.replace("S", "s")
elif isinstance(freq_as_offset, cfoff.Millisecond):
freq = freq.replace("L", "ms")
elif isinstance(freq_as_offset, cfoff.Microsecond):
freq = freq.replace("U", "us")
return freq
def _xarray_defaults(**kwargs):
"""Translate from xscen's extract names to intake-esm names and put better defaults."""
if "xr_open_kwargs" in kwargs:
kwargs["xarray_open_kwargs"] = kwargs.pop("xr_open_kwargs")
if "xr_combine_kwargs" in kwargs:
kwargs["xarray_combine_by_coords_kwargs"] = kwargs.pop("xr_combine_kwargs")
kwargs.setdefault("xarray_open_kwargs", {}).setdefault("chunks", {})
kwargs.setdefault("xarray_combine_by_coords_kwargs", {}).setdefault("data_vars", "minimal")
return kwargs
def rechunk_for_resample(obj: xr.DataArray | xr.Dataset, **resample_kwargs) -> xr.DataArray | xr.Dataset:
r"""
Rechunk object for resampling.
Parameters
----------
obj : xr.DataArray or xr.Dataset
Object.
**resample_kwargs : dict
Resampling keyword arguments.
Returns
-------
xr.DataArray or xr.Dataset
Rechunked object for resampling.
"""
if not uses_dask(obj):
return obj
res = obj.resample(**resample_kwargs)
return flox.xarray.rechunk_for_blockwise(obj, res._dim, res._codes)