Source code for xscen.io

"""Input/Output functions for xscen."""

import datetime
import logging
import os
import shutil as sh
from collections import defaultdict
from collections.abc import Sequence
from inspect import signature
from pathlib import Path
from typing import Optional, Union

import h5py
import netCDF4
import numpy as np
import pandas as pd
import xarray as xr
import zarr
from numcodecs.bitround import BitRound
from rechunker import rechunk as _rechunk
from xclim.core.calendar import get_calendar
from xclim.core.options import METADATA_LOCALES
from xclim.core.options import OPTIONS as XC_OPTIONS

from .config import parse_config
from .scripting import TimeoutException
from .utils import TRANSLATOR, season_sort_key, translate_time_chunk

logger = logging.getLogger(__name__)
KEEPBITS = defaultdict(lambda: 12)


__all__ = [
    "clean_incomplete",
    "estimate_chunks",
    "get_engine",
    "make_toc",
    "rechunk",
    "rechunk_for_saving",
    "round_bits",
    "save_to_netcdf",
    "save_to_table",
    "save_to_zarr",
    "subset_maxsize",
    "to_table",
]


def get_engine(file: Union[str, os.PathLike]) -> str:
    """Use functionality of h5py to determine if a NetCDF file is compatible with h5netcdf.

    Parameters
    ----------
    file : str or os.PathLike
        Path to the file.

    Returns
    -------
    str
        Engine to use with xarray
    """
    # find the ideal engine for xr.open_mfdataset
    if Path(file).suffix == ".zarr":
        engine = "zarr"
    elif h5py.is_hdf5(file):
        engine = "h5netcdf"
    else:
        engine = "netcdf4"

    return engine
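
# Usage sketch, assuming a local file "tasmax_day_example.nc" (illustrative path):
# pick the xarray engine before opening the file.
#
#     engine = get_engine("tasmax_day_example.nc")
#     ds = xr.open_dataset("tasmax_day_example.nc", engine=engine)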

def estimate_chunks(  # noqa: C901
    ds: Union[str, os.PathLike, xr.Dataset],
    dims: list,
    target_mb: float = 50,
    chunk_per_variable: bool = False,
) -> dict:
    """Return an approximate chunking for a file or dataset.

    Parameters
    ----------
    ds : xr.Dataset, str
        Either a xr.Dataset or the path to a NetCDF file. Existing chunks are not taken into account.
    dims : list
        Dimension(s) on which to estimate the chunking. Not implemented for more than 2 dimensions.
    target_mb : float
        Roughly the size of chunks (in Mb) to aim for.
    chunk_per_variable : bool
        If True, the output will be separated per variable. Otherwise, a common chunking will be found.

    Returns
    -------
    dict
        A dictionary mapping dimensions to chunk sizes.
    """

    def _estimate_chunks(ds, target_mb, size_of_slice, rechunk_dims):
        # Approximate size of the chunks (equal across dims)
        approx_chunks = np.power(target_mb / size_of_slice, 1 / len(rechunk_dims))

        chunks_per_dim = dict()
        if len(rechunk_dims) == 1:
            rounding = (
                1
                if ds[rechunk_dims[0]].shape[0] <= 15
                else 5 if ds[rechunk_dims[0]].shape[0] <= 250 else 10
            )
            chunks_per_dim[rechunk_dims[0]] = np.max(
                [
                    np.min(
                        [
                            int(rounding * np.round(approx_chunks / rounding)),
                            ds[rechunk_dims[0]].shape[0],
                        ]
                    ),
                    1,
                ]
            )
        elif len(rechunk_dims) == 2:
            # Adjust approx_chunks based on the ratio of the rectangle sizes
            for d in rechunk_dims:
                rounding = (
                    1 if ds[d].shape[0] <= 15 else 5 if ds[d].shape[0] <= 250 else 10
                )
                adjusted_chunk = int(
                    rounding
                    * np.round(
                        approx_chunks
                        * (
                            ds[d].shape[0]
                            / np.prod(
                                [
                                    ds[dd].shape[0]
                                    for dd in rechunk_dims
                                    if dd not in [d]
                                ]
                            )
                        )
                        / rounding
                    )
                )
                chunks_per_dim[d] = np.max(
                    [np.min([adjusted_chunk, ds[d].shape[0]]), 1]
                )
        else:
            raise NotImplementedError(
                "estimating chunks on more than 2 dimensions is not implemented yet."
            )

        return chunks_per_dim

    out = {}
    # If ds is the path to a file, use NetCDF4
    if isinstance(ds, (str, os.PathLike)):
        ds = netCDF4.Dataset(ds, "r")

        # Loop on variables
        for v in ds.variables:
            # Find dimensions to chunk
            rechunk_dims = list(set(dims).intersection(ds.variables[v].dimensions))
            if not rechunk_dims:
                continue

            dtype_size = ds.variables[v].datatype.itemsize
            num_elem_per_slice = np.prod(
                [ds[d].shape[0] for d in ds[v].dimensions if d not in rechunk_dims]
            )
            size_of_slice = (num_elem_per_slice * dtype_size) / 1024**2

            estimated_chunks = _estimate_chunks(
                ds, target_mb, size_of_slice, rechunk_dims
            )
            for other in set(ds[v].dimensions).difference(dims):
                estimated_chunks[other] = -1

            if chunk_per_variable:
                out[v] = estimated_chunks
            else:
                for d in estimated_chunks:
                    if (d not in out) or (out[d] > estimated_chunks[d]):
                        out[d] = estimated_chunks[d]

    # Else, use xarray
    else:
        for v in ds.data_vars:
            # Find dimensions to chunk
            rechunk_dims = list(set(dims).intersection(ds[v].dims))
            if not rechunk_dims:
                continue

            dtype_size = ds[v].dtype.itemsize
            num_elem_per_slice = np.prod(
                [ds[d].shape[0] for d in ds[v].dims if d not in rechunk_dims]
            )
            size_of_slice = (num_elem_per_slice * dtype_size) / 1024**2

            estimated_chunks = _estimate_chunks(
                ds, target_mb, size_of_slice, rechunk_dims
            )
            for other in set(ds[v].dims).difference(dims):
                estimated_chunks[other] = -1

            if chunk_per_variable:
                out[v] = estimated_chunks
            else:
                for d in estimated_chunks:
                    if (d not in out) or (out[d] > estimated_chunks[d]):
                        out[d] = estimated_chunks[d]

    return out
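
# Usage sketch, assuming a dataset `ds` with "time", "lat" and "lon" dimensions
# (illustrative names): estimate a common chunking targeting ~50 MB chunks along
# "time", then reuse it when saving (e.g. save_to_zarr(..., rechunk=chunks)).
#
#     chunks = estimate_chunks(ds, dims=["time"], target_mb=50)
#     # e.g. {"time": 1460, "lat": -1, "lon": -1} (values depend on the data)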

def subset_maxsize(
    ds: xr.Dataset,
    maxsize_gb: float,
) -> list:
    """Estimate a dataset's size and, if higher than the given limit, subset it along the 'time' dimension.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to be saved.
    maxsize_gb : float
        Target size for the NetCDF files.
        If the dataset is bigger than this number, it will be separated along the 'time' dimension.

    Returns
    -------
    list
        List of xr.Dataset subset along 'time' to limit the filesize to the requested maximum.
    """
    # Estimate the size of the dataset
    size_of_file = 0
    for v in ds:
        dtype_size = ds[v].dtype.itemsize
        varsize = np.prod(list(ds[v].sizes.values()))
        size_of_file = size_of_file + (varsize * dtype_size) / 1024**3

    if size_of_file < maxsize_gb:
        logger.info(f"Dataset is already smaller than {maxsize_gb} Gb.")
        return [ds]

    elif "time" in ds:
        years = np.unique(ds.time.dt.year)
        ratio = int(len(years) / (size_of_file / maxsize_gb))
        ds_sub = []
        for y in range(years[0], years[-1], ratio):
            ds_sub.extend([ds.sel({"time": slice(str(y), str(y + ratio - 1))})])
        return ds_sub

    else:
        raise NotImplementedError(
            f"Size of the NetCDF file exceeds the {maxsize_gb} Gb target, but the dataset does not contain a 'time' variable."
        )
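
# Usage sketch, assuming a large daily dataset `ds` (illustrative): split along
# 'time' so each output file stays under roughly 10 GB, then save each part.
#
#     for i, sub in enumerate(subset_maxsize(ds, maxsize_gb=10)):
#         save_to_netcdf(sub, f"output_part{i}.nc")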

def clean_incomplete(path: Union[str, os.PathLike], complete: Sequence[str]) -> None:
    """Delete un-catalogued variables from a zarr folder.

    The goal of this function is to clean up an incomplete calculation.
    It will remove any variable in the zarr that is neither in the `complete` list nor in the `coords`.

    Parameters
    ----------
    path : str, Path
        A path to a zarr folder.
    complete : sequence of strings
        Name of variables that were completed.

    Returns
    -------
    None
    """
    path = Path(path)

    with xr.open_zarr(path) as ds:
        complete = set(complete).union(ds.coords.keys())

    for fold in filter(lambda p: p.is_dir(), path.iterdir()):
        if fold.name not in complete:
            logger.warning(f"Removing {fold} from disk")
            sh.rmtree(fold)
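
# Usage sketch, assuming a store "out.zarr" where only "tasmax" finished writing
# (illustrative names): remove the folders of any other, incomplete variable.
#
#     clean_incomplete("out.zarr", complete=["tasmax"])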

def _coerce_attrs(attrs):
    """Ensure no funky objects in attrs."""
    for k in list(attrs.keys()):
        if not (
            isinstance(attrs[k], (str, float, int, np.ndarray))
            or isinstance(attrs[k], (tuple, list))
            and isinstance(attrs[k][0], (str, float, int))
        ):
            attrs[k] = str(attrs[k])


def _np_bitround(array: xr.DataArray, keepbits: int):
    """Bitround for Arrays."""
    codec = BitRound(keepbits=keepbits)
    data = array.copy()  # otherwise overwrites the input
    encoded = codec.encode(data)
    return codec.decode(encoded)

def round_bits(da: xr.DataArray, keepbits: int):
    """Round floating point variable by keeping a given number of bits in the mantissa, dropping the rest.

    This allows for a much better compression.

    Parameters
    ----------
    da : xr.DataArray
        Variable to be rounded.
    keepbits : int
        The number of bits of the mantissa to keep.

    Returns
    -------
    xr.DataArray
        The bit-rounded variable, with its history attribute updated.
    """
    da = xr.apply_ufunc(
        _np_bitround, da, keepbits, dask="parallelized", keep_attrs=True
    )
    da.attrs["_QuantizeBitRoundNumberOfSignificantDigits"] = keepbits
    new_history = f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Data compressed with BitRound by keeping {keepbits} bits."
    history = (
        new_history + " \n " + da.attrs["history"]
        if "history" in da.attrs
        else new_history
    )
    da.attrs["history"] = history
    return da
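
# Usage sketch, assuming a float variable `ds.tasmax` (illustrative): keeping 12
# mantissa bits bounds the relative rounding error by 2**-13 (about 0.012 %),
# while making the data much more compressible.
#
#     ds["tasmax"] = round_bits(ds.tasmax, keepbits=12)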

def _get_keepbits(bitround: Union[bool, int, dict], varname: str, vartype):
    # Guess the number of bits to keep depending on how bitround was passed, the var dtype and the var name.
    if not np.issubdtype(vartype, np.floating) or bitround is False:
        if isinstance(bitround, dict) and varname in bitround:
            raise ValueError(
                f"A keepbits value was given for variable {varname} even though it is not of a floating dtype."
            )
        return None
    if bitround is True:
        return KEEPBITS[varname]
    if isinstance(bitround, int):
        return bitround
    if isinstance(bitround, dict):
        return bitround.get(varname, KEEPBITS[varname])
    return None

@parse_config
def save_to_netcdf(
    ds: xr.Dataset,
    filename: Union[str, os.PathLike],
    *,
    rechunk: Optional[dict] = None,
    bitround: Union[bool, int, dict] = False,
    compute: bool = True,
    netcdf_kwargs: Optional[dict] = None,
):
    """Save a Dataset to NetCDF, rechunking or compressing if requested.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to be saved.
    filename : str or os.PathLike
        Name of the NetCDF file to be saved.
    rechunk : dict, optional
        This is a mapping from dimension name to new chunks (in any format understood by dask).
        Spatial dimensions can be generalized as 'X' and 'Y', which will be mapped to the actual grid type's
        dimension names. Rechunking is only done on *data* variables sharing dimensions with this argument.
    bitround : bool or int or dict
        If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa,
        allowing for a much better compression.
        If an int, this is the number of bits to keep for all float variables.
        If a dict, a mapping from variable name to the number of bits to keep.
        If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12,
        which yields a relative error below 0.013%.
    compute : bool
        Whether to start the computation or return a delayed object.
    netcdf_kwargs : dict, optional
        Additional arguments to send to `to_netcdf()`.

    Returns
    -------
    None

    See Also
    --------
    xarray.Dataset.to_netcdf
    """
    if rechunk:
        ds = rechunk_for_saving(ds, rechunk)

    path = Path(filename)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Prepare to_netcdf kwargs
    netcdf_kwargs = netcdf_kwargs or {}
    netcdf_kwargs.setdefault("engine", "h5netcdf")
    netcdf_kwargs.setdefault("format", "NETCDF4")

    for var in list(ds.data_vars.keys()):
        if keepbits := _get_keepbits(bitround, var, ds[var].dtype):
            ds = ds.assign({var: round_bits(ds[var], keepbits)})
        # Remove original_shape from encoding, since it can cause issues with some engines.
        ds[var].encoding.pop("original_shape", None)

    _coerce_attrs(ds.attrs)
    for var in ds.variables.values():
        _coerce_attrs(var.attrs)

    return ds.to_netcdf(filename, compute=compute, **netcdf_kwargs)
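
# Usage sketch, assuming a dataset `ds` with float variables on a lat/lon grid
# (illustrative names and chunk sizes): rechunk spatially and bit-round before
# writing to NetCDF.
#
#     save_to_netcdf(ds, "tas_example.nc", rechunk={"time": -1, "X": 50, "Y": 50}, bitround=12)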

@parse_config
def save_to_zarr(  # noqa: C901
    ds: xr.Dataset,
    filename: Union[str, os.PathLike],
    *,
    rechunk: Optional[dict] = None,
    zarr_kwargs: Optional[dict] = None,
    compute: bool = True,
    encoding: Optional[dict] = None,
    bitround: Union[bool, int, dict] = False,
    mode: str = "f",
    itervar: bool = False,
    timeout_cleanup: bool = True,
):
    """Save a Dataset to Zarr format, rechunking and compressing if requested.

    According to mode, removes variables that we don't want to re-compute in ds.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to be saved.
    filename : str
        Name of the Zarr file to be saved.
    rechunk : dict, optional
        This is a mapping from dimension name to new chunks (in any format understood by dask).
        Spatial dimensions can be generalized as 'X' and 'Y', which will be mapped to the actual grid type's
        dimension names. Rechunking is only done on *data* variables sharing dimensions with this argument.
    zarr_kwargs : dict, optional
        Additional arguments to send to `to_zarr()`.
    compute : bool
        Whether to start the computation or return a delayed object.
    mode : {'f', 'o', 'a'}
        If 'f', fails if any variable already exists.
        If 'o', removes the existing variables.
        If 'a', skips existing variables and writes the others.
    encoding : dict, optional
        If given, skipped variables are popped in place.
    bitround : bool or int or dict
        If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa,
        allowing for a much better compression.
        If an int, this is the number of bits to keep for all float variables.
        If a dict, a mapping from variable name to the number of bits to keep.
        If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12,
        which yields a relative error below 0.013%.
    itervar : bool
        If True, (data) variables are written one at a time, appending to the zarr.
        If False, this function computes, no matter what was passed to kwargs.
    timeout_cleanup : bool
        If True (default) and a :py:class:`xscen.scripting.TimeoutException` is raised during the writing,
        the variable being written is removed from the dataset as it is incomplete.
        This does nothing if `compute` is False.

    Returns
    -------
    dask.delayed object if compute=False, None otherwise.

    See Also
    --------
    xarray.Dataset.to_zarr
    """
    # to address this issue https://github.com/pydata/xarray/issues/3476
    for v in list(ds.coords.keys()):
        if ds.coords[v].dtype == object:
            ds[v].encoding.clear()

    if rechunk:
        ds = rechunk_for_saving(ds, rechunk)

    path = Path(filename)
    path.parent.mkdir(parents=True, exist_ok=True)

    if path.is_dir():
        tgtds = zarr.open(str(path), mode="r")
    else:
        tgtds = {}

    if encoding:
        encoding = encoding.copy()

    # Prepare to_zarr kwargs
    if zarr_kwargs is None:
        zarr_kwargs = {}

    def _skip(var):
        exists = var in tgtds

        if mode == "f" and exists:
            raise ValueError(f"Variable {var} exists in dataset {path}.")

        if mode == "o":
            if exists:
                var_path = path / var
                logger.warning(f"Removing {var_path} to overwrite.")
                sh.rmtree(var_path)
            return False

        if mode == "a":
            if "append_dim" not in zarr_kwargs:
                return exists
            return False

    for var in list(ds.data_vars.keys()):
        if _skip(var):
            logger.info(f"Skipping {var} in {path}.")
            ds = ds.drop_vars(var)
            if encoding:
                encoding.pop(var)
            continue
        if keepbits := _get_keepbits(bitround, var, ds[var].dtype):
            ds = ds.assign({var: round_bits(ds[var], keepbits)})
        # Remove original_shape from encoding, since it can cause issues with some engines.
        ds[var].encoding.pop("original_shape", None)

    if len(ds.data_vars) == 0:
        return None

    _coerce_attrs(ds.attrs)
    for var in ds.variables.values():
        _coerce_attrs(var.attrs)

    if itervar:
        zarr_kwargs["compute"] = True
        allvars = set(ds.data_vars.keys())
        if mode == "f":
            dsbase = ds.drop_vars(allvars)
            dsbase.to_zarr(path, **zarr_kwargs)
        if mode == "o":
            dsbase = ds.drop_vars(allvars)
            dsbase.to_zarr(path, **zarr_kwargs, mode="w")
        for i, (name, var) in enumerate(ds.data_vars.items()):
            logger.debug(f"Writing {name} ({i + 1} of {len(ds.data_vars)}) to {path}")
            dsvar = ds.drop_vars(allvars - {name})
            try:
                dsvar.to_zarr(
                    path,
                    mode="a",
                    encoding={k: v for k, v in (encoding or {}).items() if k in dsvar},
                    **zarr_kwargs,
                )
            except TimeoutException:
                if timeout_cleanup:
                    logger.info(f"Removing incomplete {name}.")
                    sh.rmtree(path / name)
                raise
    else:
        logger.debug(f"Writing {list(ds.data_vars.keys())} for {filename}.")
        try:
            return ds.to_zarr(
                filename, compute=compute, mode="a", encoding=encoding, **zarr_kwargs
            )
        except TimeoutException:
            if timeout_cleanup:
                logger.info(
                    f"Removing incomplete {list(ds.data_vars.keys())} for {filename}."
                )
                for name in ds.data_vars:
                    sh.rmtree(path / name)
            raise
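
# Usage sketch, assuming a dataset `ds` holding several indicators (illustrative
# names and chunk sizes): write variables one at a time, overwriting any that
# already exist in the store.
#
#     save_to_zarr(
#         ds,
#         "indicators.zarr",
#         rechunk={"time": 4, "X": 50, "Y": 50},
#         mode="o",
#         itervar=True,
#     )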

def _to_dataframe(
    data: xr.DataArray,
    row: list[str],
    column: list[str],
    coords: list[str],
    coords_dims: dict,
):
    """Convert a DataArray to a DataFrame with support for MultiColumn."""
    df = data.to_dataframe()
    if not column:
        # Fast track for the easy case where xarray's default is already what we want.
        return df
    df_data = (
        df[[data.name]]
        .reset_index()
        .pivot(index=row, columns=column)
        .droplevel(None, axis=1)
    )
    dfs = []
    for v in coords:
        drop_cols = [c for c in column if c not in coords_dims[v]]
        cols = [c for c in column if c in coords_dims[v]]
        dfc = (
            df[[v]].reset_index().drop(columns=drop_cols).pivot(index=row, columns=cols)
        )
        cols = dfc.columns
        # The "None" level has the aux coord name. We want it either at the same level
        # as the variable, or at the lowest missing level otherwise.
        varname_lvl = "variable" if "variable" in drop_cols else drop_cols[-1]
        cols = cols.rename(
            varname_lvl
            if not isinstance(cols, pd.MultiIndex)
            else [nm or varname_lvl for nm in cols.name]
        )
        if isinstance(df_data.columns, pd.MultiIndex) or isinstance(
            cols, pd.MultiIndex
        ):
            # Handle different depths of multicolumns: expand the MultiCol of the coord with None for missing levels.
            cols = pd.MultiIndex.from_arrays(
                [
                    cols.get_level_values(lvl) if lvl in cols.names else [None]
                    for lvl in df_data.columns.names
                ],
                names=df_data.columns.names,
            )
        dfc.columns = cols
        dfs.append(
            dfc[~dfc.index.duplicated()]
        )  # We dropped columns, thus the index is not unique anymore
    dfs.append(df_data)
    return pd.concat(dfs, axis=1).sort_index(level=row, key=season_sort_key)

def to_table(
    ds: Union[xr.Dataset, xr.DataArray],
    *,
    row: Optional[Union[str, Sequence[str]]] = None,
    column: Optional[Union[str, Sequence[str]]] = None,
    sheet: Optional[Union[str, Sequence[str]]] = None,
    coords: Union[bool, str, Sequence[str]] = True,
) -> Union[pd.DataFrame, dict]:
    """Convert a dataset to a pandas DataFrame with support for multicolumns and multisheet.

    This function will trigger a computation of the dataset.

    Parameters
    ----------
    ds : xr.Dataset or xr.DataArray
        Dataset or DataArray to be saved.
        If a Dataset with more than one variable is given, the dimension "variable"
        must appear in one of `row`, `column` or `sheet`.
    row : str or sequence of str, optional
        Name of the dimension(s) to use as indexes (rows).
        Default is all data dimensions.
    column : str or sequence of str, optional
        Name of the dimension(s) to use as columns.
        Default is "variable", i.e. the name of the variable(s).
    sheet : str or sequence of str, optional
        Name of the dimension(s) to use as sheet names.
    coords : bool or str or sequence of str
        A list of auxiliary coordinates to add to the columns (as would variables).
        If True, all (if any) are added.

    Returns
    -------
    pd.DataFrame or dict
        DataFrame with a MultiIndex with levels `row` and MultiColumn with levels `column`.
        If `sheet` is given, the output is a dictionary with one key per unique "sheet" dimension tuple
        and DataFrames as values.
        The DataFrames are always sorted with level priority as given in `row` and in ascending order.
    """
    if isinstance(ds, xr.Dataset):
        da = ds.to_array(name="data")
        if len(ds) == 1:
            da = da.isel(variable=0).rename(data=da.variable.values[0])

    def _ensure_list(seq):
        if isinstance(seq, str):
            return [seq]
        return list(seq)

    passed_dims = set().union(
        _ensure_list(row or []), _ensure_list(column or []), _ensure_list(sheet or [])
    )
    if row is None:
        row = [d for d in da.dims if d != "variable" and d not in passed_dims]
    row = _ensure_list(row)
    if column is None:
        column = ["variable"] if len(ds) > 1 and "variable" not in passed_dims else []
    column = _ensure_list(column)
    if sheet is None:
        sheet = []
    sheet = _ensure_list(sheet)

    needed_dims = row + column + sheet
    if len(set(needed_dims)) != len(needed_dims):
        raise ValueError(
            f"Repeated dimension names. Got row={row}, column={column} and sheet={sheet}. "
            "Each dimension should appear only once."
        )
    if set(needed_dims) != set(da.dims):
        raise ValueError(
            f"Passed row, column and sheet do not match available dimensions. Got {needed_dims}, data has {da.dims}."
        )

    if coords is not True:
        coords = _ensure_list(coords or [])
        drop = set(ds.coords.keys()) - set(da.dims) - set(coords)
        da = da.drop_vars(drop)
    else:
        coords = list(set(ds.coords.keys()) - set(da.dims))
    if len(coords) > 1 and ("variable" in row or "variable" in sheet):
        raise NotImplementedError(
            "Keeping auxiliary coords is not implemented when 'variable' is in the row or in the sheets. "
            "Pass `coords=False` or put 'variable' in `column` instead."
        )

    table_kwargs = dict(
        row=row,
        column=column,
        coords=coords,
        coords_dims={c: ds[c].dims for c in coords},
    )
    if sheet:
        out = {}
        das = da.stack(sheet=sheet)
        for elem in das.sheet:
            out[elem.item()] = _to_dataframe(
                das.sel(sheet=elem, drop=True), **table_kwargs
            )
        return out
    return _to_dataframe(da, **table_kwargs)
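
# Usage sketch, assuming a dataset of several ensemble statistics with dimensions
# "horizon" and "percentile" (illustrative names): one row per horizon, with a
# (variable, percentile) MultiColumn.
#
#     df = to_table(ds, row="horizon", column=["variable", "percentile"])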

def make_toc(
    ds: Union[xr.Dataset, xr.DataArray], loc: Optional[str] = None
) -> pd.DataFrame:
    """Make a table of content describing a dataset's variables.

    This returns a simple DataFrame with variable names as index, the long_name as "description" and units.
    Column names and long names are taken from the activated locale if found, otherwise the English version is taken.

    Parameters
    ----------
    ds : xr.Dataset or xr.DataArray
        Dataset or DataArray from which to extract the relevant metadata.
    loc : str, optional
        The locale to use. If None, either the first locale in the list of activated xclim locales is used,
        or "en" if none is activated.

    Returns
    -------
    pd.DataFrame
        A DataFrame with variables as index, and columns "description" and "units".
    """
    if loc is None:
        loc = (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]
    locsuf = "" if loc == "en" else f"_{loc}"
    _ = TRANSLATOR[loc]  # Combine translation and gettext parsing (like it usually is)

    if isinstance(ds, xr.DataArray):
        ds = ds.to_dataset()

    toc = pd.DataFrame.from_records(
        [
            {
                _("Variable"): vv,
                _("Description"): da.attrs.get(
                    f"long_name{locsuf}", da.attrs.get("long_name")
                ),
                _("Units"): da.attrs.get("units"),
            }
            for vv, da in ds.data_vars.items()
        ],
    ).set_index(_("Variable"))
    toc.attrs["name"] = _("Content")

    return toc
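
# Usage sketch, assuming a dataset `ds` with long_name and units attributes
# (illustrative): build a table of content, in English by default or in the
# first activated xclim locale.
#
#     toc = make_toc(ds)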
TABLE_FORMATS = {".csv": "csv", ".xls": "excel", ".xlsx": "excel"}

def save_to_table(
    ds: Union[xr.Dataset, xr.DataArray],
    filename: Union[str, os.PathLike],
    output_format: Optional[str] = None,
    *,
    row: Optional[Union[str, Sequence[str]]] = None,
    column: Union[None, str, Sequence[str]] = "variable",
    sheet: Optional[Union[str, Sequence[str]]] = None,
    coords: Union[bool, Sequence[str]] = True,
    col_sep: str = "_",
    row_sep: Optional[str] = None,
    add_toc: Union[bool, pd.DataFrame] = False,
    **kwargs,
):
    """Save the dataset to a tabular file (csv, excel, ...).

    This function will trigger a computation of the dataset.

    Parameters
    ----------
    ds : xr.Dataset or xr.DataArray
        Dataset or DataArray to be saved.
        If a Dataset with more than one variable is given, the dimension "variable"
        must appear in one of `row`, `column` or `sheet`.
    filename : str or os.PathLike
        Name of the file to be saved.
    output_format : {'csv', 'excel', ...}, optional
        The output format. If None (default), it is inferred from the extension of `filename`.
        Not all possible output formats are supported for inference.
        Valid values are any that matches a :py:class:`pandas.DataFrame` method like "df.to_{format}".
    row : str or sequence of str, optional
        Name of the dimension(s) to use as indexes (rows).
        Default is all data dimensions.
    column : str or sequence of str, optional
        Name of the dimension(s) to use as columns.
        Default is "variable", i.e. the name of the variable(s).
    sheet : str or sequence of str, optional
        Name of the dimension(s) to use as sheet names.
        Only valid if the output format is excel.
    coords : bool or sequence of str
        A list of auxiliary coordinates to add to the columns (as would variables).
        If True, all (if any) are added.
    col_sep : str
        Multi-columns (except in excel) and sheet names are concatenated with this separator.
    row_sep : str, optional
        Multi-index names are concatenated with this separator, except in excel.
        If None (default), each level is written in its own column.
    add_toc : bool or DataFrame
        A table of content to add as the first sheet. Only valid if the output format is excel.
        If True, :py:func:`make_toc` is used to generate the toc.
        The sheet name of the toc can be given through the "name" attribute of the DataFrame,
        otherwise "Content" is used.
    kwargs :
        Other arguments passed to the pandas function.
        If the output format is excel, kwargs to :py:class:`pandas.ExcelWriter` can be given here as well.
    """
    filename = Path(filename)

    if output_format is None:
        output_format = TABLE_FORMATS.get(filename.suffix)
    if output_format is None:
        raise ValueError(
            f"Output format could not be inferred from filename {filename.name}. Please pass `output_format`."
        )

    if sheet is not None and output_format != "excel":
        raise ValueError(
            f"Argument `sheet` is only valid with excel as the output format. Got {output_format}."
        )
    if add_toc is not False and output_format != "excel":
        raise ValueError(
            f"A TOC was requested, but the output format is not Excel. Got {output_format}."
        )

    out = to_table(ds, row=row, column=column, sheet=sheet, coords=coords)

    if add_toc is not False:
        if not sheet:
            out = {("data",): out}
        if add_toc is True:
            add_toc = make_toc(ds)
        out = {(add_toc.attrs.get("name", "Content"),): add_toc, **out}

    if sheet or (add_toc is not False):
        engine_kwargs = {}  # Extract engine kwargs
        for arg in signature(pd.ExcelWriter).parameters:
            if arg in kwargs:
                engine_kwargs[arg] = kwargs.pop(arg)
        with pd.ExcelWriter(filename, **engine_kwargs) as writer:
            for sheet_name, df in out.items():
                df.to_excel(writer, sheet_name=col_sep.join(sheet_name), **kwargs)
    else:
        if output_format != "excel" and isinstance(out.columns, pd.MultiIndex):
            out.columns = out.columns.map(lambda lvls: col_sep.join(map(str, lvls)))
        if (
            output_format != "excel"
            and row_sep is not None
            and isinstance(out.index, pd.MultiIndex)
        ):
            new_name = row_sep.join(out.index.names)
            out.index = out.index.map(lambda lvls: row_sep.join(map(str, lvls)))
            out.index.name = new_name
        getattr(out, f"to_{output_format}")(filename, **kwargs)
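
# Usage sketch, assuming the same ensemble statistics dataset as above
# (illustrative dimension names): write an Excel workbook with one sheet per
# horizon and a table of content as the first sheet.
#
#     save_to_table(
#         ds,
#         "indicators.xlsx",
#         row="percentile",
#         column="variable",
#         sheet="horizon",
#         add_toc=True,
#     )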

def rechunk_for_saving(ds: xr.Dataset, rechunk: dict):
    """Rechunk before saving to .zarr or .nc, generalized as Y/X for different axes lat/lon, rlat/rlon.

    Parameters
    ----------
    ds : xr.Dataset
        The xr.Dataset to be rechunked.
    rechunk : dict
        A dictionary with the dimension names of ds and the new chunk size.
        Spatial dimensions can be provided as X/Y.

    Returns
    -------
    xr.Dataset
        The dataset with new chunking.
    """
    for rechunk_var in ds.data_vars:
        # Support for chunks varying per variable
        if rechunk_var in rechunk:
            rechunk_dims = rechunk[rechunk_var].copy()
        else:
            rechunk_dims = rechunk.copy()

        # get actual axes labels
        if "X" in rechunk_dims and "X" not in ds.dims:
            rechunk_dims[ds.cf.axes["X"][0]] = rechunk_dims.pop("X")
        if "Y" in rechunk_dims and "Y" not in ds.dims:
            rechunk_dims[ds.cf.axes["Y"][0]] = rechunk_dims.pop("Y")

        ds[rechunk_var] = ds[rechunk_var].chunk(
            {d: chnks for d, chnks in rechunk_dims.items() if d in ds[rechunk_var].dims}
        )
        ds[rechunk_var].encoding["chunksizes"] = tuple(
            rechunk_dims[d] if d in rechunk_dims else ds[d].shape[0]
            for d in ds[rechunk_var].dims
        )
        ds[rechunk_var].encoding.pop("chunks", None)
        ds[rechunk_var].encoding.pop("preferred_chunks", None)

    return ds
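
# Usage sketch, assuming a CF-aware dataset `ds` whose X/Y axes cf_xarray can
# identify (illustrative chunk sizes): chunk spatially by 50 and yearly in time
# before saving.
#
#     ds = rechunk_for_saving(ds, {"X": 50, "Y": 50, "time": 365})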

@parse_config
def rechunk(
    path_in: Union[os.PathLike, str, xr.Dataset],
    path_out: Union[os.PathLike, str],
    *,
    chunks_over_var: Optional[dict] = None,
    chunks_over_dim: Optional[dict] = None,
    worker_mem: str,
    temp_store: Optional[Union[os.PathLike, str]] = None,
    overwrite: bool = False,
) -> None:
    """Rechunk a dataset into a new zarr.

    Parameters
    ----------
    path_in : path, str or xr.Dataset
        Input to rechunk.
    path_out : path or str
        Path to the target zarr.
    chunks_over_var : dict
        Mapping from variables to mappings from dimension name to size.
        Give this argument or `chunks_over_dim`.
    chunks_over_dim : dict
        Mapping from dimension name to size that will be used for all variables in ds.
        Give this argument or `chunks_over_var`.
    worker_mem : str
        The maximal memory usage of each task.
        When using a distributed Client, this is an approximate memory per thread.
        Each worker of the client should have access to 10-20% more memory than this times the number of threads.
    temp_store : path or str, optional
        A path to a zarr where to store intermediate results.
    overwrite : bool
        If True, it will delete whatever is in path_out before doing the rechunking.

    Returns
    -------
    None

    See Also
    --------
    rechunker.rechunk
    """
    if Path(path_out).is_dir() and overwrite:
        sh.rmtree(path_out)

    if isinstance(path_in, os.PathLike) or isinstance(path_in, str):
        path_in = Path(path_in)
        if path_in.suffix == ".zarr":
            ds = xr.open_zarr(path_in)
        else:
            ds = xr.open_dataset(path_in)
    else:
        ds = path_in
    variables = list(ds.data_vars)
    if chunks_over_var:
        chunks = chunks_over_var
    elif chunks_over_dim:
        chunks = {v: {d: chunks_over_dim[d] for d in ds[v].dims} for v in variables}
        chunks.update(time=None, lat=None, lon=None)
        cal = get_calendar(ds)
        Nt = ds.time.size
        chunks = translate_time_chunk(chunks, cal, Nt)
    else:
        raise ValueError(
            "No chunks given. Need to give at least one of `chunks_over_var` or `chunks_over_dim`."
        )

    plan = _rechunk(ds, chunks, worker_mem, str(path_out), temp_store=str(temp_store))

    plan.execute()
    zarr.consolidate_metadata(path_out)

    if temp_store is not None:
        sh.rmtree(temp_store)
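
# Usage sketch, assuming an existing "in.zarr" store and a dask distributed
# client already running (illustrative paths and chunk sizes): rechunk on disk,
# applying the same per-dimension chunks to all variables.
#
#     rechunk(
#         "in.zarr",
#         "out.zarr",
#         chunks_over_dim={"time": 365, "lat": 50, "lon": 50},
#         worker_mem="2GB",
#         temp_store="/tmp/rechunk_tmp.zarr",
#     )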