"""Functions to train and adjust a dataset using a bias-adjustment algorithm."""
import ast
import logging
import warnings
from copy import deepcopy
import xarray as xr
import xclim as xc
import xsdba
from xsdba.processing import from_additive_space, to_additive_space
from .catutils import parse_from_ds
from .config import parse_config
from .utils import minimum_calendar, standardize_periods, xclim_convert_units_to
logger = logging.getLogger(__name__)
xsdba.set_options(extra_output=False)
__all__ = [
"adjust",
"train",
]
def _add_preprocessing_attr(scen, train_kwargs):
fake_ref = xr.DataArray(name="ref")
fake_hist = xr.DataArray(name="hist")
scen.attrs["bias_adjustment"] += f" with xsdba_train_args: {train_kwargs['xsdba_train_args']}"
preproc = []
if train_kwargs["jitter_under"] is not None:
preproc.append(xc.core.formatting.gen_call_string("jitter_under_thresh", fake_ref, fake_hist, train_kwargs["jitter_under"]))
if train_kwargs["jitter_over"] is not None:
preproc.append(xc.core.formatting.gen_call_string("jitter_over_thresh", fake_ref, fake_hist, train_kwargs["jitter_over"]))
if preproc:
scen.attrs["bias_adjustment"] += ", ref and hist were prepared with " + " and ".join(preproc)
return scen
[docs]
@parse_config
def train(
dref: xr.Dataset,
dhist: xr.Dataset,
var: str | list[str],
period: list[str],
*,
method: str = "DetrendedQuantileMapping",
group: xsdba.Grouper | str | dict | bool | None = None,
xsdba_train_args: dict | None = None,
xclim_train_args: dict | None = None,
maximal_calendar: str = "noleap",
jitter_under: dict | None = None,
jitter_over: dict | None = None,
align_on: str | None = "year",
additive_space: dict | None = None,
) -> xr.Dataset:
"""
Train a bias-adjustment.
Parameters
----------
dref : xr.Dataset
The target timeseries, on the reference period.
dhist : xr.Dataset
The timeseries to adjust, on the reference period.
var : str or list of str
Variable(s) on which to do the adjustment.
period : list of str
[start, end] of the reference period.
method : str
Name of the `xsdba.TrainAdjust` method of xclim.
group : str or xsdba.Grouper or dict or bool, optional
Grouping information. If a string, it is interpreted as a grouper on the time dimension.
If a dict, it is passed to `xsdba.Grouper.from_kwargs`.
Defaults to {"group": "time.dayofyear", "window": 31}.
If False, this argument will be skipped and never passed to the adjustment.
xsdba_train_args : dict, optional
Dict of arguments to pass to the `.train` of the adjustment object.
xclim_train_args : dict, optional
Dict of arguments to pass to the `.train` of the adjustment object.
A warning will be emitted stating that this a legacy argument replaced with `xsdba_train_args`.
maximal_calendar : str
Maximal calendar dhist can be. The hierarchy: 360_day < noleap < standard < all_leap.
If dhist's calendar is higher than maximal calendar, it will be converted to the maximal calendar.
jitter_under : dict, optional
If given, a dictionary of args to pass to `jitter_under_thresh`.
jitter_over : dict, optional
If given, a dictionary of args to pass to `jitter_over_thresh`.
align_on : str, optional
`align_on` argument for the function `xr.DataArray.convert_calendar`.
additive_space : dict, optional
A dictionary of variables and their arguments to convert them to additive space.
The transformation will be applied to both `dref` and `dhist` datasets,
as well as `dsim` and `dref` in `adjust`.
Finally, `from_additive_space` will be called on the output of `adjust`.
The keys are the variable names, and the values are the arguments for `to_additive_space`.
If given, `kind` in `xsdba_train_args` must be '+'.
Returns
-------
xr.Dataset
Trained algorithm's data.
See Also
--------
xsdba.adjustment.DetrendedQuantileMapping : Detrended Quantile Mapping bias-adjustment.
xsdba.adjustment.ExtremeValues : Adjustment correction for extreme values.
xsdba.processing.to_additive_space : Transform a non-additive variable into an additive space by the means of a log or logit transformation.
xsdba.processing.from_additive_space : Transform back to the physical space a variable that was transformed with to_additive_space.
"""
if xclim_train_args is not None:
warnings.warn(
"`xclim_train_args` will be deprecated and replaced by `xsdba_train_args`.",
FutureWarning,
stacklevel=2,
)
if xsdba_train_args is not None:
warnings.warn(
"`xclim_train_args` and `xsdba_train_args` were both given, but correspond to the same option. `xsdba_train_args` will be kept",
stacklevel=2,
)
else:
xsdba_train_args = deepcopy(xclim_train_args)
if isinstance(var, str):
var = [var]
# transforms
additive_space = additive_space or {}
if additive_space:
for add_var, add_args in additive_space.items():
dref[add_var] = to_additive_space(dref[add_var], **add_args)
dhist[add_var] = to_additive_space(dhist[add_var], **add_args)
if "kind" in xsdba_train_args and xsdba_train_args["kind"] != "+":
warnings.warn("`additive_space` was given, but `kind` in `xsdba_train_args` is not '+'.", stacklevel=2)
if len(var) == 1:
ref = dref[var[0]]
hist = dhist[var[0]]
else:
# Eventually, we can change ["MBCn"] and add more supported multivariate methods
if method not in ["MBCn"]:
raise ValueError(f"Multiple variables were given: {var}, but this treatment only works with a multivariate method,got {method}.")
ref = xsdba.stack_variables(dref[var])
hist = xsdba.stack_variables(dhist[var])
group = group if group is not None else {"group": "time.dayofyear", "window": 31}
xsdba_train_args = xsdba_train_args or {}
xsdba_train_args_copy = deepcopy(xsdba_train_args) # for train args
if method == "DetrendedQuantileMapping":
xsdba_train_args.setdefault("nquantiles", 15)
# cut out the right period
period = standardize_periods(period, multiple=False)
hist = hist.sel(time=slice(*period))
ref = ref.sel(time=slice(*period))
# convert calendar if necessary
simcal = hist.time.dt.calendar
refcal = ref.time.dt.calendar
mincal = minimum_calendar(simcal, maximal_calendar)
if simcal != mincal:
hist = hist.convert_calendar(mincal, align_on=align_on)
if refcal != mincal:
ref = ref.convert_calendar(mincal, align_on=align_on)
if isinstance(group, dict):
# So we can specify window and add_dims in yaml.
group = xsdba.Grouper.from_kwargs(**group)["group"]
elif isinstance(group, str):
group = xsdba.Grouper(group)
if group is not False:
if method != "MBCn":
xsdba_train_args["group"] = group
else:
xsdba_train_args.setdefault("base_kws", {})
xsdba_train_args["base_kws"]["group"] = group
with xclim_convert_units_to():
if jitter_over is not None:
ref = xsdba.processing.jitter_over_thresh(ref, **jitter_over)
hist = xsdba.processing.jitter_over_thresh(hist, **jitter_over)
if jitter_under is not None:
ref = xsdba.processing.jitter_under_thresh(ref, **jitter_under)
hist = xsdba.processing.jitter_under_thresh(hist, **jitter_under)
ADJ = getattr(xsdba.adjustment, method).train(ref, hist, **xsdba_train_args)
ds = ADJ.ds
# Arguments that need to be transferred to the adjust() function
xsdba_train_args.pop("group", None)
ds.attrs["train_params"] = {
"var": var,
"maximal_calendar": maximal_calendar,
"xsdba_train_args": xsdba_train_args_copy,
"jitter_under": jitter_under,
"jitter_over": jitter_over,
"period": period,
"additive_space": additive_space,
}
# attrs that are needed to open with .to_dataset_dict()
for a in ["cat:xrfreq", "cat:domain", "cat:id"]:
ds.attrs[a] = dhist.attrs[a] if a in dhist.attrs else None
ds.attrs["cat:processing_level"] = f"training_{'_'.join(var)}"
ds.attrs["cat:bias_adjust_reference"] = dref.attrs.get("cat:source", "unknown")
return ds
[docs]
@parse_config
def adjust( # noqa: C901
dtrain: xr.Dataset,
dsim: xr.Dataset,
periods: list[str] | list[list[str]],
*,
stack_periods: dict | None = None,
dref: xr.Dataset | None = None,
xsdba_adjust_args: dict | None = None,
xclim_adjust_args: dict | None = None,
to_level: str = "biasadjusted",
bias_adjust_institution: str | None = None,
bias_adjust_project: str | None = None,
bias_adjust_reference: str | None = None,
align_on: str | None = "year",
) -> xr.Dataset:
"""
Adjust a simulation.
Parameters
----------
dtrain : xr.Dataset
A trained algorithm's dataset, as returned by `train`.
dsim : xr.Dataset
Simulated timeseries, projected period.
periods : list of str or list of lists of str
Either [start, end] or list of [start, end] of the simulation periods to be adjusted (one at a time).
stack_periods : dict, optional
Dictionary of arguments to pass to `xsdba.stack_periods` before adjustment. `xsdba.unstack_periods` will be called after.
If given, the 'periods' argument must contain a single period, which will be subsetted before calling `xsdba.stack_periods`.
dref : xr.Dataset, optional
Reference timeseries, needed only for certain methods.
xsdba_adjust_args : dict, optional
Dict of arguments to pass to the `.adjust` of the adjustment object.
xclim_adjust_args : dict, optional
Dict of arguments to pass to the `.adjust` of the adjustment object
A warning will be emitted stating that this a legacy argument replaced with `xclim_train_args`.
to_level : str
The processing level to assign to the output.
Defaults to 'biasadjusted'.
bias_adjust_institution : str, optional
The institution to assign to the output.
bias_adjust_project : str, optional
The project to assign to the output.
bias_adjust_reference : str, optional
FIXME: What is this exactly? TBD.
align_on : str, optional
`align_on` argument for the function `xr.DataArray.convert_calendar`.
Returns
-------
xr.Dataset
The bias-adjusted timeseries (dscen).
See Also
--------
xsdba.adjustment.DetrendedQuantileMapping : Detrended Quantile Mapping bias-adjustment.
xsdba.adjustment.ExtremeValues : Adjustment correction for extreme values.
Notes
-----
If `dref` is given as input, `dref` and `dsim` must have the same (non-zero) number
of time steps over training period given as input in ``xscen.train``.
"""
if xclim_adjust_args is not None:
warnings.warn(
"`xclim_adjust_args` will be deprecated and replaced by `xsdba_adjust_args`.",
FutureWarning,
stacklevel=2,
)
if xsdba_adjust_args is not None:
warnings.warn(
"`xclim_adjust_args` and `xsdba_adjust_args` were both given, but correspond to the same option. `xsdba_adjust_args` will be kept",
stacklevel=2,
)
else:
xsdba_adjust_args = deepcopy(xclim_adjust_args)
xsdba_adjust_args = deepcopy(xsdba_adjust_args)
xsdba_adjust_args = xsdba_adjust_args or {}
# evaluate the dict that was stored as a string
if not isinstance(dtrain.attrs["train_params"], dict):
# FIXME: eval is bad. There has to be a better way!™
dtrain.attrs["train_params"] = ast.literal_eval(dtrain.attrs["train_params"]) # noqa: S307
# transforms
additive_space = dtrain.attrs["train_params"]["additive_space"]
if additive_space:
for add_var, add_args in additive_space.items():
if dref:
dref[add_var] = to_additive_space(dref[add_var], **add_args)
dsim[add_var] = to_additive_space(dsim[add_var], **add_args)
var = dtrain.attrs["train_params"]["var"]
if len(var) == 1:
var = var[0]
sim = dsim[var]
else:
sim = xsdba.stack_variables(dsim[var])
# get right calendar
simcal = sim.time.dt.calendar
mincal = minimum_calendar(simcal, dtrain.attrs["train_params"]["maximal_calendar"])
if simcal != mincal:
sim = sim.convert_calendar(mincal, align_on=align_on)
# get right calendar for `dref` too, if defined
if dref is not None:
ref = xsdba.stack_variables(dref[var])
refcal = dref.time.dt.calendar
mincal = minimum_calendar(refcal, dtrain.attrs["train_params"]["maximal_calendar"])
if refcal != mincal:
ref = ref.convert_calendar(mincal, align_on=align_on)
# Used in MBCn adjusting (maybe other multivariate methods in the future)
train_period = dtrain.attrs["train_params"]["period"]
ref = ref.sel(time=slice(*train_period))
hist = sim.sel(time=slice(*train_period))
if (hist.time.size != ref.time.size) or (ref.time.size == 0):
raise ValueError(
" If `dref` was given as input, `dref` and `dsim` must have the same (non-zero) number of time steps "
"over the training period `period` defined in the ``xscen.train``, but this is not the case."
)
xsdba_adjust_args["ref"] = ref
xsdba_adjust_args["hist"] = hist
# adjust
ADJ = xsdba.adjustment.TrainAdjust.from_dataset(dtrain)
if ("detrend" in xsdba_adjust_args) and (isinstance(xsdba_adjust_args["detrend"], dict)):
name, kwargs = list(xsdba_adjust_args["detrend"].items())[0]
kwargs = kwargs or {}
kwargs.setdefault("group", ADJ.group)
kwargs.setdefault("kind", ADJ.kind)
xsdba_adjust_args["detrend"] = getattr(xsdba.detrending, name)(**kwargs)
with xclim_convert_units_to():
# do the adjustment for all the simulation_period lists
periods = standardize_periods(periods)
# if 'period_dim' is specified in 'xsdba_adjust_args', or if 'stack_periods' is given, use stacking
if xsdba_adjust_args.get("period_dim", None) is not None or stack_periods is not None:
if len(periods) > 1:
raise ValueError(
"Period stacking (`period_dim` specified in `xsdba_adjust_args`) is not allowed with multiple time slices in `periods`."
)
stack_periods = deepcopy(stack_periods) or {}
period_dim = xsdba_adjust_args.get("period_dim", "period")
if "dim" in stack_periods and stack_periods["dim"] != period_dim:
warnings.warn(
f"`dim` in `stack_periods` ({stack_periods['dim']}) is different from `period_dim` "
f"in `xsdba_adjust_args` ({period_dim}). Using `period_dim` value.",
UserWarning,
stacklevel=2,
)
stack_periods.pop("dim")
sim_stacked = xsdba.stack_periods(sim.sel(time=slice(*periods[0])), dim=period_dim, **stack_periods).chunk({period_dim: -1})
# Also stack `scen` if needed
if "scen" in xsdba_adjust_args:
scen = xsdba_adjust_args["scen"]
scen_stacked = xsdba.stack_periods(scen.sel(time=slice(*periods[0])), dim=period_dim, **stack_periods).chunk({period_dim: -1})
xsdba_adjust_args["scen"] = scen_stacked
out = ADJ.adjust(sim_stacked, **xsdba_adjust_args)
dscen = xsdba.unstack_periods(out, dim=period_dim)
# do the adjustment for all the simulation_period lists
else:
slices = []
for period in periods:
sim_sel = sim.sel(time=slice(period[0], period[1]))
out = ADJ.adjust(sim_sel, **xsdba_adjust_args)
slices.extend([out])
# put all the adjusted period back together
dscen = xr.concat(slices, dim="time")
dscen = _add_preprocessing_attr(dscen, dtrain.attrs["train_params"])
if isinstance(var, str):
dscen = xr.Dataset(data_vars={var: dscen}, attrs=dsim.attrs)
else:
dscen = xsdba.unstack_variables(dscen)
if additive_space:
for add_var in additive_space.keys():
dscen[add_var] = from_additive_space(dscen[add_var])
dscen.attrs["cat:processing_level"] = to_level
dscen.attrs["cat:variable"] = parse_from_ds(dscen, ["variable"])["variable"]
dscen.attrs["cat:bias_adjust_institution"] = bias_adjust_institution or "unknown"
dscen.attrs["cat:bias_adjust_project"] = bias_adjust_project or "unknown"
dscen.attrs["cat:bias_adjust_reference"] = bias_adjust_reference or dtrain.attrs.get("cat:bias_adjust_reference", None) or "unknown"
return dscen