Source code for sharrow.relationships

import ast
import logging
import warnings

import networkx as nx
import numpy as np
import pandas as pd
import xarray as xr

from .dataset import Dataset, construct

try:
    from dask.array import Array as dask_array_type
except ModuleNotFoundError:
    dask_array_type = ()

try:
    from sparse import SparseArray as sparse_array_type
except ModuleNotFoundError:
    sparse_array_type = ()

try:
    from ast import unparse
except ImportError:
    from astunparse import unparse as _unparse

    def unparse(*args):
        return _unparse(*args).strip("\n")


logger = logging.getLogger("sharrow")

well_known_names = {
    "nb",
    "np",
    "pd",
    "xr",
    "pa",
    "log",
    "exp",
    "log1p",
    "expm1",
    "max",
    "min",
    "piece",
    "hard_sigmoid",
    "transpose_leading",
    "clip",
}

NOTSET = "<--NOTSET-->"


def _require_string(x):
    if not isinstance(x, str):
        raise ValueError("must be string")
    return x


def _iat(source, *, _names=None, _load=False, _index_name=None, **idxs):
    """
    Select from `source` by integer position, using one indexer per keyword.

    Parameters
    ----------
    source : xarray.DataArray or xarray.Dataset
        The object to select from.
    _names : optional
        If truthy, `source[_names]` is selected before indexing.
    _load : bool, default False
        If true, `_load()` is called on the (sub)selected object first.
    _index_name : str, optional
        Dimension name given to every dimension of every indexer.  When
        omitted, dimensions are named sequentially "index1", "index2", ...
    **idxs
        Maps dimension names of `source` to arrays of positions.  Each
        value must expose an `ndim` attribute.

    Returns
    -------
    The result of `isel` applied with the wrapped indexers.
    """
    counter = 0

    def _next_dim_name():
        nonlocal counter
        if _index_name is not None:
            return _index_name
        counter += 1
        return f"index{counter}"

    indexers = {
        dim: xr.DataArray(
            values, dims=[_next_dim_name() for _ in range(values.ndim)]
        )
        for dim, values in idxs.items()
    }
    subset = source[_names] if _names else source
    if _load:
        subset = subset._load()
    return subset.isel(**indexers)


def _at(source, *, _names=None, _load=False, _index_name=None, **idxs):
    """
    Select from `source` by coordinate label, using one indexer per keyword.

    Parameters
    ----------
    source : xarray.DataArray or xarray.Dataset
        The object to select from.
    _names : optional
        If truthy, `source[_names]` is selected before indexing.
    _load : bool, default False
        If true, `_load()` is called on the (sub)selected object first.
    _index_name : str, optional
        Dimension name given to every dimension of every indexer.  When
        omitted, dimensions are named sequentially "index1", "index2", ...
    **idxs
        Maps dimension names of `source` to arrays of labels.  Each value
        must expose an `ndim` attribute.

    Returns
    -------
    The result of `sel` applied with the wrapped indexers.
    """
    counter = 0

    def _next_dim_name():
        nonlocal counter
        if _index_name is not None:
            return _index_name
        counter += 1
        return f"index{counter}"

    indexers = {
        dim: xr.DataArray(
            values, dims=[_next_dim_name() for _ in range(values.ndim)]
        )
        for dim, values in idxs.items()
    }
    subset = source[_names] if _names else source
    if _load:
        subset = subset._load()
    return subset.sel(**indexers)


def gather(source, indexes):
    """
    Extract values by label along the dimensions named by `indexes`.

    Parameters
    ----------
    source : xarray.DataArray or xarray.Dataset
        The source of the values to extract.
    indexes : Mapping[str, array-like]
        The keys of `indexes` (if given as a dataframe, the column names)
        should match the named dimensions of `source`.  Each value holds
        the labels to look up along that dimension.

    Returns
    -------
    xarray.DataArray or xarray.Dataset
        The label-based selection from `source`, after
        `reset_coords(drop=True)` has been applied to it.
    """
    return _at(source, **indexes).reset_coords(drop=True)


def igather(source, positions):
    """
    Extract values by position along the dimensions named by `positions`.

    Parameters
    ----------
    source : xarray.DataArray or xarray.Dataset
        The source of the values to extract.
    positions : pd.DataFrame or Mapping[str, array-like]
        The columns (or keys) of `positions` should match the named
        dimensions of `source`.  Each value holds the integer positions to
        look up along that dimension.

    Returns
    -------
    xarray.DataArray or xarray.Dataset
        The position-based selection from `source`, after
        `reset_coords(drop=True)` has been applied to it.
    """
    return _iat(source, **positions).reset_coords(drop=True)


def xgather(source, positions, indexes):
    """
    Extract values from `source` by position, by label, or both.

    Positional lookups (`positions`) are applied first via `igather`, and
    label lookups (`indexes`) are then applied to that result via `gather`.
    When `indexes` is empty only `igather` is used; when `positions` is
    empty only `gather` is used.
    """
    if not len(indexes):
        return igather(source, positions)
    if not len(positions):
        return gather(source, indexes)
    by_position = igather(source, positions)
    return gather(by_position, indexes)


def _dataarray_to_numpy(self) -> np.ndarray:
    """
    Return the data wrapped by `self` as a `numpy.ndarray`.

    A dask array is evaluated with `compute()`, a sparse array is expanded
    with `todense()`, and the outcome is passed through `np.asarray`.
    """
    values = self.data
    if isinstance(values, dask_array_type):
        values = values.compute()
    if isinstance(values, sparse_array_type):
        values = values.todense()
    return np.asarray(values)


class Relationship:
    """
    Defines a linkage between datasets in a `DataTree`.

    Parameters
    ----------
    parent_data : str
        Name of the parent dataset.
    parent_name : str
        Variable in the parent dataset that references the child dimension.
    child_data : str
        Name of the child dataset.
    child_name : str
        Dimension in the child dataset that is used by this relationship.
    indexing : {'label', 'position'}, default 'label'
        How the target dimension is used.
    analog : str, optional
        Original variable that defined a label-based relationship before
        digitization.

    Raises
    ------
    ValueError
        If any of the four names is not a string, or `indexing` is not one
        of the two permitted values.
    """

    def __init__(
        self,
        parent_data,
        parent_name,
        child_data,
        child_name,
        indexing="label",
        analog=None,
    ):
        self.parent_data = _require_string(parent_data)
        """str: Name of the parent dataset."""

        self.parent_name = _require_string(parent_name)
        """str: Variable in the parent dataset that references the child dimension."""

        self.child_data = _require_string(child_data)
        """str: Name of the child dataset."""

        self.child_name = _require_string(child_name)
        """str: Dimension in the child dataset that is used by this relationship."""

        if indexing not in {"label", "position"}:
            raise ValueError("indexing must be by label or position")
        self.indexing = indexing
        """str: How the target dimension is used, either by 'label' or 'position'."""

        self.analog = analog
        """str: Original variable that defined label-based relationship before digitization."""

    # Instances are mutable and compared by value; defining __eq__ already
    # disables hashing, this just states it explicitly.
    __hash__ = None

    def _label_form(self):
        """Return the text used for equality comparison.

        A relationship carrying an `analog` is rendered as the label-based
        relationship on that analog, so it compares equal to the label-based
        relationship it was derived from.
        """
        if self.analog:
            return (
                f"<Relationship by label: "
                f"{self.parent_data}[{self.analog!r}] -> "
                f"{self.child_data}[{self.child_name!r}]>"
            )
        return repr(self)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            # let Python try the reflected comparison / fall back to identity
            # instead of silently returning None
            return NotImplemented
        return self._label_form() == other._label_form()

    def __repr__(self):
        return (
            f"<Relationship by {self.indexing}: "
            f"{self.parent_data}[{self.parent_name!r}] -> "
            f"{self.child_data}[{self.child_name!r}]>"
        )

    def attrs(self):
        """Return the attributes stored on the graph edge for this relationship."""
        return dict(
            parent_name=self.parent_name,
            child_name=self.child_name,
            indexing=self.indexing,
        )

    def to_dict(self):
        """Return all constructor arguments of this relationship as a dict."""
        return dict(
            parent_data=self.parent_data,
            parent_name=self.parent_name,
            child_data=self.child_data,
            child_name=self.child_name,
            indexing=self.indexing,
            analog=self.analog,
        )

    @classmethod
    def from_string(cls, s):
        """
        Construct a `Relationship` from a string.

        Parameters
        ----------
        s : str
            The relationship definition.  To create a label-based
            relationship, the string should look like
            "ParentNode.variable_name @ ChildNode.dimension_name".  To
            create a position-based relationship, give
            "ParentNode.variable_name -> ChildNode.dimension_name".

        Returns
        -------
        Relationship

        Raises
        ------
        ValueError
            If `s` contains neither "->" nor "@", or if either side is not
            written as "node_name.variable_name".
        """
        # "->" is checked first, so a string containing both is positional
        if "->" in s:
            parent, child = s.split("->", 1)
            indexing = "position"
        elif "@" in s:
            parent, child = s.split("@", 1)
            indexing = "label"
        else:
            raise ValueError(f"cannot interpret relationship {s!r}")
        try:
            parent_data, parent_name = parent.split(".", 1)
            child_data, child_name = child.split(".", 1)
        except ValueError:
            raise ValueError(
                f"cannot interpret relationship {s!r}: both sides must be "
                f"given as 'node_name.variable_name'"
            ) from None
        return cls(
            parent_data=parent_data.strip(),
            parent_name=parent_name.strip(),
            child_data=child_data.strip(),
            child_name=child_name.strip(),
            indexing=indexing,
        )
class DataTree:
    """
    A tree representing linked datasets, from which data can flow.

    Parameters
    ----------
    graph : networkx.MultiDiGraph
    root_node_name : str or False
        The name of the node at the root of the tree.
    extra_funcs : Tuple[Callable]
        Additional functions that can be called by Flow objects created
        using this DataTree.  These functions should have defined `__name__`
        attributes, so they can be called in expressions.
    extra_vars : Mapping[str,Any], optional
        Additional named constants that can be referenced by expressions in
        Flow objects created using this DataTree.
    cache_dir : Path-like, optional
        The default directory where Flow objects are created.
    relationships : Iterable[str or Relationship]
        The relationship definitions used to define this tree.  All dataset
        nodes named in these relationships should also be included as
        keyword arguments for this constructor.
    force_digitization : bool, default False
        Whether to automatically digitize all relationships (converting them
        from label-based to position-based).  Digitization is required to
        evaluate Flows, but doing so automatically on construction may be
        inefficient.
    dim_order : Tuple[str], optional
        The order of dimensions to use in Flow outputs.  Generally only needed
        if there are multiple dimensions in the root dataset.
    aux_vars : Mapping[str,Any], optional
        Additional named arrays or numba-typable variables that can be
        referenced by expressions in Flow objects created using this DataTree.
    **kwargs
        Datasets to add as nodes of the tree, keyed by node name.
    """

    DatasetType = Dataset

    def __init__(
        self,
        graph=None,
        root_node_name=None,
        extra_funcs=(),
        extra_vars=None,
        cache_dir=None,
        relationships=(),
        force_digitization=False,
        dim_order=None,
        aux_vars=None,
        **kwargs,
    ):
        if isinstance(graph, Dataset):
            raise ValueError("datasets must be given as keyword arguments")

        # raw init
        self._graph = nx.MultiDiGraph() if graph is None else graph
        self._root_node_name = None
        self.force_digitization = force_digitization
        self.dim_order = dim_order
        self.dim_exclude = set()

        # defined init: the root dataset (when supplied) is added ahead of
        # every other dataset, then the root name is assigned via the setter
        has_named_root = root_node_name is not None
        if has_named_root and root_node_name in kwargs:
            self.add_dataset(root_node_name, kwargs[root_node_name])
        self.root_node_name = root_node_name
        self.extra_funcs = extra_funcs
        self.extra_vars = extra_vars or {}
        self.aux_vars = aux_vars or {}
        self.cache_dir = cache_dir
        self._cached_indexes = {}
        for node_name, node_data in kwargs.items():
            if has_named_root and node_name == root_node_name:
                continue  # already added above
            self.add_dataset(node_name, node_data)
        for rel in relationships:
            self.add_relationship(rel)
        if force_digitization:
            self.digitize_relationships(inplace=True)

        # These filters are applied to incoming datasets when using `replace_datasets`.
        self.replacement_filters = {}
        """Dict[Str,Callable]: Filters that are automatically applied to data on replacement.

        When individual datasets are replaced in the tree, the incoming dataset
        is passed through the filter with a matching name-key (if it exists).
        The filter should be a function that accepts one argument (the incoming
        dataset) and returns one value (the dataset to save in the tree).  These
        filters can be used to ensure data quality, e.g. renaming variables,
        ensuring particular data types, etc.
        """

        self.subspace_fallbacks = {}
        """Dict[Str:List[Str]]: Allowable fallback subspace lookups.

        When a named variable is not found in a given subspace, the default
        result is raising a KeyError.  But, if fallbacks are defined for a
        given subspace, the fallbacks are searched in order for the desired
        variable.
        """

    @property
    def shape(self):
        """Tuple[int]: base shape of arrays that will be loaded when using this DataTree."""
        ordering = self.dim_order
        if not ordering:
            from .flows import presorted

            ordering = presorted(self.root_dataset.dims, self.dim_order)
        return tuple(
            self.root_dataset.dims[i] for i in ordering if i not in self.dim_exclude
        )

    @property
    def root_dims(self):
        """Tuple[str]: root dataset dimension names, as arranged by `presorted`."""
        from .flows import presorted

        arranged = presorted(self.root_dataset.dims, self.dim_order, self.dim_exclude)
        return tuple(arranged)

    def __shallow_copy_extras(self):
        # keyword arguments needed to rebuild a tree of the same type around
        # a different graph
        return {
            "extra_funcs": self.extra_funcs,
            "extra_vars": self.extra_vars,
            "aux_vars": self.aux_vars,
            "cache_dir": self.cache_dir,
            "force_digitization": self.force_digitization,
            "dim_order": self.dim_order,
        }

    def __repr__(self):
        parts = [f"<{self.__module__}.{self.__class__.__name__}>"]
        if len(self._graph.nodes):
            parts.append("\n datasets:")
            root = self.root_node_name
            if root:
                # the root dataset is always listed first
                parts.append(f"\n - {root}")
            parts.extend(f"\n - {k}" for k in self._graph.nodes if k != root)
        else:
            parts.append("\n datasets: none")
        if len(self._graph.edges):
            parts.append("\n relationships:")
            for edge in self._graph.edges:
                line = f"\n - {self._get_relationship(edge)!r}"
                parts.append(line.replace("<Relationship ", "").rstrip(">"))
        else:
            parts.append("\n relationships: none")
        return "".join(parts)

    def view_relationships(self, fontname="Arial", fontsize=9):
        """Display an SVG rendering of the graph of this tree."""
        from .viz import display_svg, make_graph

        drawing = make_graph(self, fontname=fontname, fontsize=fontsize)
        return display_svg(drawing)

    def _hash_features(self):
        """Return a list of strings naming the datasets, relationships and dim order."""
        features = []
        if len(self._graph.nodes):
            root = self.root_node_name
            if root:
                features.append(f"dataset:{root}")
            features.extend(f"dataset:{k}" for k in self._graph.nodes if k != root)
        else:
            features.append("datasets:none")
        if len(self._graph.edges):
            for edge in self._graph.edges:
                text = f"relationship:{self._get_relationship(edge)!r}"
                features.append(text.replace("<Relationship ", "").rstrip(">"))
        else:
            features.append("relationships:none")
        features.append(f"dim_order:{self.dim_order}")
        return features

    @property
    def root_node_name(self):
        """str: The root node for this data tree, which is only ever a parent."""
        if self._root_node_name is None:
            # adopt the first node that has no incoming edges
            orphans = (
                node for node in self._graph.nodes if self._graph.in_degree(node) == 0
            )
            for first in orphans:
                self._root_node_name = first
                break
        return self._root_node_name

    @root_node_name.setter
    def root_node_name(self, name):
        if name is None or name is False:
            self._root_node_name = name
            return
        if not isinstance(name, str):
            raise TypeError(
                f"root_node_name must be one of [str, None, False] not {type(name)}"
            )
        if name not in self._graph.nodes:
            raise KeyError(name)
        self._root_node_name = name

    @property
    def root_node_name_str(self):
        """str: The root node for this data tree, which is only ever a parent.

        This method raises a ValueError if root node cannot be determined.
        """
        if self._root_node_name is None:
            for candidate in self._graph.nodes:
                if self._graph.in_degree(candidate) != 0:
                    continue
                self._root_node_name = candidate
                break
        resolved = self._root_node_name
        if resolved is None:
            raise ValueError("root node cannot be determined")
        if resolved is False:
            raise ValueError("root node is False")
        return resolved
def add_relationship(self, *args, **kwargs):
    """
    Add a relationship to this DataTree.

    The new relationship will point from a variable in one dataset to a
    dimension of another dataset in this tree.  Both the parent and the
    child datasets should already have been added.

    Parameters
    ----------
    *args, **kwargs
        All arguments are passed through to the `Relationship` constructor,
        unless only a single positional argument is provided that is either
        a `Relationship` (used as given) or a `str` (parsed with the
        `Relationship.from_string` class constructor).
    """
    sole = args[0] if len(args) == 1 else None
    if isinstance(sole, Relationship):
        r = sole
    elif isinstance(sole, str):
        r = Relationship.from_string(sole)
    else:
        r = Relationship(*args, **kwargs)

    # check for existing relationships, don't duplicate
    if any(r == self._get_relationship(edge) for edge in self._graph.edges):
        return

    # confirm correct pointer: attach the edge to the node where the parent
    # variable is actually found
    r.parent_data = self.finditem(r.parent_name, maybe_in=r.parent_data)
    self._graph.add_edge(r.parent_data, r.child_data, **r.attrs())

    if self.force_digitization:
        self.digitize_relationships(inplace=True)
def get_relationship(self, parent, child):
    """
    Get the relationship stored on the edge from `parent` to `child`.

    Parameters
    ----------
    parent, child : str
        Names of the parent and child dataset nodes.

    Returns
    -------
    Relationship
        When the graph is a multigraph holding several parallel edges
        between the two nodes, the first one stored is returned.

    Raises
    ------
    KeyError
        If there is no edge from `parent` to `child`.
    """
    edge_data = self._graph.get_edge_data(parent, child)
    if edge_data is None:
        raise KeyError((parent, child))
    if self._graph.is_multigraph():
        # The default graph is a MultiDiGraph, whose edge view can only be
        # indexed by (parent, child, key); get_edge_data instead reports
        # {key: attrs} for every parallel edge.
        attrs = next(iter(edge_data.values()))
    else:
        attrs = edge_data
    return Relationship(parent_data=parent, child_data=child, **attrs)


def list_relationships(self) -> list[Relationship]:
    """List : List all relationships defined in this tree."""
    return [self._get_relationship(edge) for edge in self._graph.edges]
def add_dataset(self, name, dataset, relationships=(), as_root=False):
    """
    Add a new Dataset node to this DataTree.

    Parameters
    ----------
    name : str
    dataset : Dataset or pandas.DataFrame
        Will be coerced into a `Dataset` object if it is not already
        in that format, using a no-copy process if possible.
    relationships : Tuple[str or Relationship]
        Also add these relationships.  A single string is treated as one
        relationship.
    as_root : bool, default False
        Set this new node as the root of the tree, displacing any existing
        root.
    """
    node_data = construct(dataset)
    self._graph.add_node(name, dataset=node_data)
    if as_root or self.root_node_name is None:
        self.root_node_name = name
    pending = [relationships] if isinstance(relationships, str) else relationships
    for rel in pending:
        # TODO validate relationships before adding.
        self.add_relationship(rel)
    if self.force_digitization:
        self.digitize_relationships(inplace=True)
def add_items(self, items):
    """
    Add datasets and relationships given as (nested) sequences and mappings.

    Parameters
    ----------
    items : Sequence or Mapping
        Each element of a sequence is processed recursively.  In a mapping,
        a "relationships" key holds an iterable of relationships to add, a
        "name" plus "dataset" pair describes one dataset, and every other
        key/value pair is added as a dataset named by its key.

    Raises
    ------
    ValueError
        If `items` (or a nested element) is a string, or is neither a
        Sequence nor a Mapping.
    """
    from collections.abc import Mapping, Sequence

    if isinstance(items, str):
        # a str is a Sequence of one-character strings, which would otherwise
        # recurse without end
        raise ValueError("add_items requires Sequence or Mapping")
    if isinstance(items, Sequence):
        for i in items:
            self.add_items(i)
    elif isinstance(items, Mapping):
        if "name" in items and "dataset" in items:
            self.add_dataset(items["name"], items["dataset"])
            preload = True
        else:
            preload = False
        for k, v in items.items():
            if preload and k in {"name", "dataset"}:
                continue
            if k == "relationships":
                for r in v:
                    self.add_relationship(r)
            else:
                self.add_dataset(k, v)
    else:
        raise ValueError("add_items requires Sequence or Mapping")


@property
def root_node(self):
    """The graph node attributes of the root node."""
    return self._graph.nodes[self.root_node_name_str]


@property
def root_dataset(self):
    """The dataset stored on the root node."""
    return self._graph.nodes[self.root_node_name_str]["dataset"]


@root_dataset.setter
def root_dataset(self, x):
    if not isinstance(x, Dataset):
        x = construct(x)
    if self.root_node_name_str in self.replacement_filters:
        x = self.replacement_filters[self.root_node_name](x)
    self._graph.nodes[self.root_node_name]["dataset"] = x


def _get_relationship(self, edge):
    """Build a `Relationship` from a graph edge key and its stored attributes."""
    return Relationship(
        parent_data=edge[0], child_data=edge[1], **self._graph.edges[edge]
    )


def __getitem__(self, item):
    return self.get(item)


def get(self, item, default=None, broadcast=True, coords=True):
    """
    Access variable(s) from this tree.

    Parameters
    ----------
    item : str or Sequence[str]
        Each value can be just the name of the variable if that name is
        unique within the tree, or use dotted notation
        ('node_name.var_name') to give the node name explicitly and
        resolve ambiguity as necessary.
    default
        If provided, this default value is used for any missing item(s).
    broadcast : bool, default True
        Broadcast all arrays up to the dimensions of the root node in the
        tree.
    coords : bool, default True
        Attach coordinates from the root node of the tree to the result.

    Returns
    -------
    DataArray or Dataset

    Raises
    ------
    KeyError
        If an item is not found and `default` is None.
    """
    if isinstance(item, (list, tuple)):
        return Dataset(
            {
                k: self.get(k, default=default, broadcast=broadcast, coords=coords)
                for k in item
            }
        )
    try:
        result = self._getitem(item, dim_names_from_top=True)
    except KeyError:
        # second attempt: allow the name to match a bare dimension
        try:
            result = self._getitem(
                item, include_blank_dims=True, dim_names_from_top=True
            )
        except KeyError:
            if default is None:
                raise
            result = xr.DataArray(default)
    if self.root_node_name:
        root_dataset = self.root_dataset
        if result.dims != self.root_dims and broadcast:
            result, _ = xr.broadcast(result, root_dataset)
        if coords:
            add_coords = {}
            for i in result.dims:
                if i not in result.coords and i in root_dataset.coords:
                    add_coords[i] = root_dataset.coords[i]
            if add_coords:
                result = result.assign_coords(add_coords)
    elif self.root_node_name is False:
        # without a root, coordinates come from the explicitly named node
        if "." in item:
            item_in, item = item.split(".", 1)
            base_dataset = self._graph.nodes[item_in]["dataset"]
            if coords:
                add_coords = {}
                for i in result.dims:
                    if i not in result.coords and i in base_dataset.coords:
                        add_coords[i] = base_dataset.coords[i]
                if add_coords:
                    result = result.assign_coords(add_coords)
    return result


def finditem(self, item, maybe_in=None):
    """
    Find the name of the node that holds `item`.

    The node `maybe_in` is checked first (when it names a node of this
    tree); otherwise the tree is searched by `_getitem`.
    """
    if maybe_in is not None and maybe_in in self._graph.nodes:
        dataset = self._graph.nodes[maybe_in].get("dataset", {})
        if item in dataset:
            return maybe_in
    return self._getitem(item, just_node_name=True)


def _getitem(
    self,
    item,
    include_blank_dims=False,
    only_dims=False,
    just_node_name=False,
    dim_names_from_top=False,
):
    """
    Search the tree for `item` and gather it along the relationships.

    Nodes are visited first-in-first-out, starting from the root node (or
    from the node named in a dotted `item` when the root node name is False).

    Parameters
    ----------
    item : str or Sequence[str]
        A variable name, optionally dotted as 'node_name.var_name'.  A list
        or tuple yields a `Dataset` of the individual lookups.
    include_blank_dims : bool, default False
        Also match `item` against the dimension names of each dataset; a
        match yields a `RangeIndex` over that dimension.
    only_dims : bool, default False
        Never match `item` against variable names.
    just_node_name : bool, default False
        Return the name of the node where `item` was found, not the data.
    dim_names_from_top : bool, default False
        Rename gathered dimensions using names taken from the first edge on
        each path leaving the starting node.

    Raises
    ------
    KeyError
        If `item` is not found.
    """
    if isinstance(item, (list, tuple)):
        return Dataset({k: self[k] for k in item})

    if "." in item:
        item_in, item = item.split(".", 1)
        queue = [self.root_node_name]
        if self.root_node_name is False:
            # when root_node_name is False, we don't want to broadcast
            # back to the root, but instead only to the given `item_in`
            queue = [item_in]
            item_in = None
    else:
        item_in = None
        queue = [self.root_node_name_str]
    examined = set()
    start_from = queue[0]
    while len(queue):
        current_node = queue.pop(0)
        if current_node in examined:
            continue
        dataset = self._graph.nodes[current_node].get("dataset", {})
        try:
            by_name = item in dataset and not only_dims
        except TypeError:
            by_name = False
        try:
            by_dims = not by_name and include_blank_dims and (item in dataset.dims)
        except TypeError:
            by_dims = False
        if (by_name or by_dims) and (item_in is None or item_in == current_node):
            if just_node_name:
                return current_node
            if current_node == start_from:
                # found in the starting node: nothing to gather
                if by_dims:
                    return xr.DataArray(pd.RangeIndex(dataset.dims[item]), dims=item)
                else:
                    return dataset[item]
            else:
                _positions = {}
                _labels = {}
                if by_dims:
                    if item in dataset.variables:
                        coords = {item: dataset.variables[item]}
                    else:
                        coords = None
                    result = xr.DataArray(
                        pd.RangeIndex(dataset.dims[item]),
                        dims=item,
                        coords=coords,
                    )
                else:
                    result = dataset[item]
                dims_in_result = set(result.dims)
                top_dim_names = {}
                for path in nx.algorithms.simple_paths.all_simple_edge_paths(
                    self._graph, start_from, current_node
                ):
                    if dim_names_from_top:
                        e = path[0]
                        top_dim_name = self._graph.edges[e].get("parent_name")
                        start_dataset = self._graph.nodes[start_from]["dataset"]
                        # deconvert digitized dim names back to native dims
                        if (
                            top_dim_name not in start_dataset.dims
                            and top_dim_name in start_dataset.variables
                        ):
                            if start_dataset.variables[top_dim_name].ndim == 1:
                                top_dim_name = start_dataset.variables[
                                    top_dim_name
                                ].dims[0]
                    else:
                        top_dim_name = None
                    path_dim = self._graph.edges[path[-1]].get("child_name")
                    if path_dim not in dims_in_result:
                        continue
                    # path_indexing = self._graph.edges[path[-1]].get('indexing')
                    t1 = None
                    # intermediate nodes on path
                    for e, e_next in zip(path[:-1], path[1:]):
                        r = self._get_relationship(e)
                        r_next = self._get_relationship(e_next)
                        if t1 is None:
                            t1 = self._graph.nodes[r.parent_data].get("dataset")
                        t2 = self._graph.nodes[r.child_data].get("dataset")[
                            [r_next.parent_name]
                        ]
                        if r.indexing == "label":
                            t1 = t2.sel(
                                {r.child_name: _dataarray_to_numpy(t1[r.parent_name])}
                            )
                        else:  # by position
                            t1 = t2.isel(
                                {r.child_name: _dataarray_to_numpy(t1[r.parent_name])}
                            )
                    # final node in path
                    e = path[-1]
                    r = Relationship(
                        parent_data=e[0], child_data=e[1], **self._graph.edges[e]
                    )
                    if t1 is None:
                        t1 = self._graph.nodes[r.parent_data].get("dataset")
                    if r.indexing == "label":
                        _labels[r.child_name] = _dataarray_to_numpy(t1[r.parent_name])
                    else:  # by position
                        _idx = _dataarray_to_numpy(t1[r.parent_name])
                        if not np.issubdtype(_idx.dtype, np.integer):
                            _idx = _idx.astype(np.int64)
                        _positions[r.child_name] = _idx
                    if top_dim_name is not None:
                        top_dim_names[r.child_name] = top_dim_name
                y = xgather(result, _positions, _labels)
                if len(result.dims) == 1 and len(y.dims) == 1:
                    y = y.rename({y.dims[0]: result.dims[0]})
                elif len(dims_in_result) == len(y.dims):
                    y = y.rename({_i: _j for _i, _j in zip(y.dims, result.dims)})
                if top_dim_names:
                    y = y.rename(top_dim_names)
                return y
        else:
            examined.add(current_node)
            for _, next_up in self._graph.out_edges(current_node):
                if next_up not in examined:
                    queue.append(next_up)
    raise KeyError(item)


def get_expr(self, expression, engine="sharrow", allow_native=True):
    """
    Access or evaluate an expression.

    Parameters
    ----------
    expression : str
    engine : {'sharrow', 'numexpr'}
        The engine used to resolve expressions.
    allow_native : bool, default True
        If the expression is an array in a dataset of this tree, return
        that array directly.  Set to false to force evaluation, which
        will also ensure proper broadcasting consistent with this data
        tree.

    Returns
    -------
    DataArray

    Raises
    ------
    ValueError
        If evaluation is needed and `engine` is not a known engine.
    """
    try:
        if allow_native:
            result = self[expression]
        else:
            # jump straight to the evaluation branch below
            raise KeyError
    except (KeyError, IndexError):
        if engine == "sharrow":
            result = (
                self.setup_flow({expression: expression})
                .load_dataarray()
                .isel(expressions=0)
            )
        elif engine == "numexpr":
            result = xr.DataArray(
                pd.eval(expression, resolvers=[self], engine="numexpr"),
            )
        else:
            raise ValueError(f"unknown engine {engine}") from None
    return result


@property
def subspaces(self):
    """Mapping[str,Dataset] : Direct access to node Dataset objects by name."""
    return dict(self.subspaces_iter())


def subspaces_iter(self):
    """Yield `(name, dataset)` for every node that has a dataset."""
    for k in self._graph.nodes:
        s = self._graph.nodes[k].get("dataset", None)
        if s is not None:
            yield (k, s)


def contains_subspace(self, key) -> bool:
    """
    Is this named Dataset in this tree's subspaces.

    Parameters
    ----------
    key : str

    Returns
    -------
    bool
    """
    return key in self._graph.nodes


def get_subspace(self, key, default_empty=False) -> xr.Dataset:
    """
    Access named Dataset from this tree's subspaces.

    Parameters
    ----------
    key : str
    default_empty : bool, default False
        Return an empty Dataset if the key is not found.

    Returns
    -------
    xr.Dataset

    Raises
    ------
    KeyError
        If the key is not found (or its node has no dataset) and
        `default_empty` is false.
    """
    try:
        node = self._graph.nodes[key]
    except KeyError:
        # a missing node is handled like a node without a dataset, so that
        # `default_empty` is honored in both situations
        node = {}
    result = node.get("dataset", None)
    if result is None:
        if default_empty:
            result = xr.Dataset()
        else:
            raise KeyError(key)
    return result


def namespace_names(self):
    """Return the set of mangled names for all coordinates and variables in the tree."""
    namespace = set()
    for spacename, spacearrays in self.subspaces_iter():
        prefix = f"__{spacename or 'base'}__"
        for k in spacearrays.coords:
            namespace.add(f"{prefix}{k}")
        for k in spacearrays:
            if k.startswith("_s_"):
                # names starting with "_s_" contribute three component arrays
                namespace.add(f"{prefix}{k}__indptr")
                namespace.add(f"{prefix}{k}__indices")
                namespace.add(f"{prefix}{k}__data")
            else:
                namespace.add(f"{prefix}{k}")
    return namespace


@property
def dims(self):
    """Mapping from dimension names to lengths across all dataset nodes."""
    dims = {}
    for _k, v in self.subspaces_iter():
        for name, length in v.sizes.items():
            if name in dims:
                if dims[name] != length:
                    raise ValueError("inconsistent dimensions\n" + self.dims_detail())
            else:
                dims[name] = length
    return xr.core.utils.Frozen(dims)


sizes = dims  # alternate name


def dims_detail(self):
    """
    Report on the names and sizes of dimensions in all Dataset nodes.

    Returns
    -------
    str
    """
    lines = []
    for k, v in self.subspaces_iter():
        lines.append(f"{k}:")
        lines.extend(f" - {name}: {length}" for name, length in v.sizes.items())
    return "\n".join(lines)


def drop_dims(self, dims, inplace=False, ignore_missing_dims=True):
    """
    Drop dimensions from root Dataset node.

    Nodes reached from the root through a relationship that relies on a
    dropped dimension (or on a variable no longer in the root dataset) are
    removed as well, together with everything downstream of them.

    Parameters
    ----------
    dims : str or Iterable[str]
        One or more named dimensions to drop.
    inplace : bool, default False
        Whether to drop dimensions in-place.
    ignore_missing_dims : bool, default True
        Simply ignore any dimensions that are not present.

    Returns
    -------
    DataTree
        Returns self if dropping inplace, otherwise returns a copy with
        dimensions dropped.
    """
    if isinstance(dims, str):
        dims = [dims]
    else:
        # `dims` is scanned several times below; keep one-shot iterables alive
        dims = list(dims)
    obj = self if inplace else self.copy()
    if not ignore_missing_dims:
        new_root_dataset = obj.root_dataset.drop_dims(dims)
    else:
        new_root_dataset = obj.root_dataset
        for d in dims:
            if d in obj.root_dataset.dims:
                new_root_dataset = new_root_dataset.drop_dims(d)
    # remove subspaces that rely on dropped dim
    boot_queue = set()
    booted = set()
    for (up, dn, _n), e in obj._graph.edges.items():
        if up == obj.root_node_name:
            _analog = e.get("analog", "<missing>")
            if _analog in dims:
                boot_queue.add(dn)
            if _analog != "<missing>" and _analog not in new_root_dataset:
                boot_queue.add(dn)
            if e.get("parent_name", "<missing>") in dims:
                boot_queue.add(dn)
            if e.get("parent_name", "<missing>") not in new_root_dataset:
                boot_queue.add(dn)
    while boot_queue:
        b = boot_queue.pop()
        booted.add(b)
        for up, dn, _n in obj._graph.edges.keys():
            # skip nodes already handled so that cyclic edges cannot loop forever
            if up == b and dn not in booted:
                boot_queue.add(dn)
    edges_to_remove = [e for e in obj._graph.edges if e[1] in booted]
    obj._graph.remove_edges_from(edges_to_remove)
    obj._graph.remove_nodes_from(booted)
    obj.root_dataset = new_root_dataset
    if self.dim_order is not None:
        # dim_order defaults to None, in which case there is nothing to filter
        obj.dim_order = tuple(x for x in self.dim_order if x not in dims)
    return obj


def get_indexes(
    self,
    position_only=True,
    as_dict=True,
    replacements=None,
    use_cache=True,
    check_shapes=True,
):
    """
    Collect the dimension-only lookup of every dimension in the tree.

    Parameters
    ----------
    position_only : bool, default True
        Must be true; anything else raises NotImplementedError.
    as_dict : bool, default True
        Return a dict of numpy arrays; otherwise return a `Dataset`.
    replacements : Mapping[str,Dataset], optional
        Passed to `replace_datasets` before the indexes are computed.
    use_cache : bool, default True
        Read from and write to the cache kept on this tree.
    check_shapes : bool, default True
        Raise a ValueError when the indexes do not all share one shape.

    Returns
    -------
    dict or Dataset
    """
    # NOTE(review): the cache key ignores `replacements` and `check_shapes`,
    # so a cached result is returned regardless of them -- confirm intended.
    if use_cache and (position_only, as_dict) in self._cached_indexes:
        return self._cached_indexes[(position_only, as_dict)]
    if not position_only:
        raise NotImplementedError
    all_dims = self.dims  # computed once; the property scans every dataset
    # a dimension named "x_" is skipped when a dimension "x" also exists
    dims = [d for d in all_dims if d[-1:] != "_" or d[:-1] not in all_dims]
    if replacements is not None:
        obj = self.replace_datasets(replacements)
    else:
        obj = self
    result = {}
    result_shape = None
    for k in sorted(dims):
        result_k = obj._getitem(k, include_blank_dims=True, only_dims=True)
        if result_shape is None:
            result_shape = result_k.shape
        if result_shape != result_k.shape:
            if check_shapes:
                raise ValueError(
                    f"inconsistent index shapes {result_k.shape} v {result_shape} "
                    f"(probably an error on {k} or {sorted(dims)[0]})"
                )
        result[k] = result_k

    if as_dict:
        result = {k: _dataarray_to_numpy(v) for k, v in result.items()}
    else:
        result = Dataset(result)
    if use_cache:
        self._cached_indexes[(position_only, as_dict)] = result
    return result
def replace_datasets(self, other=None, validate=True, redigitize=True, **kwargs):
    """
    Replace one or more datasets in the nodes of this tree.

    Parameters
    ----------
    other : Mapping[str,Dataset]
        A dictionary of replacement datasets.
    validate : bool, default True
        Raise an error when replacing downstream datasets that
        are referenced by position, unless the replacement is identically
        sized.  If validation is deactivated, and an incompatible dataset
        is placed in this tree, flows that rely on that relationship will
        give erroneous results or crash with a segfault.
    redigitize : bool, default True
        Automatically re-digitize relationships that are label-based and
        were previously digitized.
    **kwargs : Mapping[str,Dataset]
        Alternative format to `other`.

    Returns
    -------
    DataTree
        A new DataTree with data replacements completed.

    Raises
    ------
    KeyError
        If a replacement names a node that is not in this tree.
    ValueError
        If validating, and a differently sized replacement is given for a
        dataset that is referenced by position.
    """
    replacements = {}
    if other is not None:
        replacements.update(other)
    # keyword replacements take precedence over entries of `other`
    replacements.update(kwargs)
    # work on a copy of the graph, so this tree itself is left unchanged
    graph = self._graph.copy()
    for k in replacements:
        if k not in graph.nodes:
            raise KeyError(k)
        x = construct(replacements[k])
        if validate:
            if x.sizes != graph.nodes[k]["dataset"].sizes:
                # when replacement dimensions do not match, check for
                # any upstream nodes that reference this dataset by
                # position... which will potentially be problematic.
                for e in self._graph.edges:
                    if e[1] == k:
                        indexing = self._graph.edges[e].get("indexing")
                        if indexing == "position":
                            raise ValueError(
                                f"dimensions mismatch on "
                                f"positionally-referenced dataset {k}: "
                                f"receiving {x.dims} "
                                f"expected {graph.nodes[k]['dataset'].dims}"
                            )
        # also if any dim coordinates are changing, redigitize
        for dim in x.dims:
            if dim in graph.nodes[k]["dataset"].coords:
                if not np.array_equal(
                    graph.nodes[k]["dataset"].coords[dim].data,
                    x.coords[dim].data,
                ):
                    # find all edges with digitized label relationships
                    # and cast them back to label
                    for e in graph.edges:
                        if e[1] == k:
                            r = self._get_relationship(e)
                            if r.child_name == dim and r.analog:
                                graph.edges[e]["indexing"] = "label"
                                graph.edges[e]["parent_name"] = r.analog
        # the filter registered for this node name (if any) is applied last
        if k in self.replacement_filters:
            x = self.replacement_filters[k](x)
        graph.nodes[k]["dataset"] = x
    result = type(self)(graph, self.root_node_name, **self.__shallow_copy_extras())
    if redigitize:
        result.digitize_relationships(inplace=True)
    return result
def setup_flow(
    self,
    definition_spec,
    *,
    cache_dir=None,
    name=None,
    dtype="float32",
    boundscheck=False,
    error_model="numpy",
    nopython=True,
    fastmath=True,
    parallel=True,
    readme=None,
    flow_library=None,
    extra_hash_data=(),
    write_hash_audit=True,
    hashing_level=1,
    dim_exclude=None,
    with_root_node_name=None,
):
    """
    Set up a new Flow for analysis using the structure of this DataTree.

    Parameters
    ----------
    definition_spec : Dict[str,str]
        Gives the names and expressions that define the variables to
        create in this new `Flow`.
    cache_dir : Path-like, optional
        A location to write out generated python and numba code. If not
        provided, the `cache_dir` of this tree is used.
    name : str, optional
        The name of this Flow used for writing out cached files. If not
        provided, a unique name is generated. If `cache_dir` is given,
        be sure to avoid name conflicts with other flows in the same
        directory.
    dtype : str, default "float32"
        The name of the numpy dtype that will be used for the output.
    boundscheck : bool, default False
        If True, boundscheck enables bounds checking for array indices, and
        out of bounds accesses will raise IndexError. The default is to not
        do bounds checking, which is faster but can produce garbage results
        or segfaults if there are problems, so try turning this on for
        debugging if you are getting unexplained errors or crashes.
    error_model : {'numpy', 'python'}, default 'numpy'
        The error_model option controls the divide-by-zero behavior. Setting
        it to 'python' causes divide-by-zero to raise exception like
        CPython. Setting it to 'numpy' causes divide-by-zero to set the
        result to +/-inf or nan.
    nopython : bool, default True
        Compile using numba's `nopython` mode.  Provided for debugging only,
        as there's little point in turning this off for production code, as
        all the speed benefits of sharrow will be lost.
    fastmath : bool, default True
        If true, fastmath enables the use of "fast" floating point transforms,
        which can improve performance but can result in tiny distortions in
        results.  See numba docs for details.
    parallel : bool, default True
        Enable or disable parallel computation for certain functions.
    readme : str, optional
        A string to inject as a comment at the top of the flow Python file.
    flow_library : Mapping[str,Flow], optional
        An in-memory cache of precompiled Flow objects.  Using this can result
        in performance improvements when repeatedly using the same definitions.
    extra_hash_data : Tuple[Hashable], optional
        Additional data used for generating the flow hash.  Useful to prevent
        conflicts when using a flow_library with multiple similar flows.
    write_hash_audit : bool, default True
        Writes a hash audit log into a comment in the flow Python file, for
        debugging purposes.
    hashing_level : int, default 1
        Level of detail to write into flow hashes.  Increase detail to avoid
        hash conflicts for similar flows.  Level 2 adds information about
        names used in expressions and digital encodings to the flow hash,
        which prevents conflicts but requires more pre-computation to generate
        the hash.
    dim_exclude : Collection[str], optional
        Exclude these root dataset dimensions from this flow.
    with_root_node_name : optional
        Passed through unchanged to the `Flow` constructor.

    Returns
    -------
    Flow
    """
    from .flows import Flow

    flow_options = {
        # an explicit cache_dir wins over the default of this tree
        "cache_dir": cache_dir or self.cache_dir,
        "name": name,
        "dtype": dtype,
        "boundscheck": boundscheck,
        "nopython": nopython,
        "fastmath": fastmath,
        "parallel": parallel,
        "readme": readme,
        "flow_library": flow_library,
        "extra_hash_data": extra_hash_data,
        "hashing_level": hashing_level,
        "error_model": error_model,
        "write_hash_audit": write_hash_audit,
        # the dimension order always comes from the tree
        "dim_order": self.dim_order,
        "dim_exclude": dim_exclude,
        "with_root_node_name": with_root_node_name,
    }
    return Flow(self, definition_spec, **flow_options)
def get_named_array(self, mangled_name):
    """
    Fetch the array addressed by a mangled variable name.

    A mangled name has the form ``__<node>__<variable>``.  The special node
    name ``aux_var`` addresses this tree's `aux_vars` mapping; any other
    node name is looked up in the graph and the variable is read from the
    dataset attached to that node.  Variables whose name starts with
    ``_s_`` and ends with ``__data``, ``__indptr`` or ``__indices`` resolve
    to the matching attribute of the underlying `.data` object of the
    variable named without that suffix.

    Parameters
    ----------
    mangled_name : str

    Returns
    -------
    array-like

    Raises
    ------
    KeyError
        If the name is not of the mangled form, if the node has no dataset
        attached, or if the variable is not found in the node's dataset.
    """
    if mangled_name[:2] != "__":
        raise KeyError(mangled_name)
    name1, separator, name2 = mangled_name[2:].partition("__")
    if not separator:
        # A prefix with no node/variable separator is not a mangled name;
        # report it as a KeyError, like any other unknown name.
        raise KeyError(mangled_name)
    if name1 == "aux_var":
        return self.aux_vars[name2]
    dataset = self._graph.nodes[name1].get("dataset")
    if dataset is None:
        # The node exists but carries no dataset, so nothing can be found.
        raise KeyError(f"{name1}.{name2}")
    if name2.startswith("_s_"):
        # Component arrays are exposed under suffixed names; strip the
        # suffix to find the variable and return the matching attribute.
        for suffix, attribute in (
            ("__data", "data"),
            ("__indptr", "indptr"),
            ("__indices", "indices"),
        ):
            if name2.endswith(suffix):
                return getattr(dataset[name2[: -len(suffix)]].data, attribute)
    try:
        _d = dataset[name2]
    except KeyError as err:
        raise KeyError(f"{name1}.{name2}") from err
    else:
        return _dataarray_to_numpy(_d)


# Prefix of dataset variable names that hold position-based offsets.
_BY_OFFSET = "digitizedOffset"
def digitize_relationships(self, inplace=False, redigitize=True):
    """
    Convert all label-based relationships into position-based.

    For every edge whose relationship uses "label" indexing, the parent
    values are translated into integer positions within the child index,
    the resulting offsets are stored in the parent dataset under a name
    beginning with `_BY_OFFSET`, and the edge is updated to point at that
    variable using "position" indexing.  The original parent variable name
    is kept on the edge as `analog`.

    Parameters
    ----------
    inplace : bool, default False
        If True, modify this tree directly and return None.  Otherwise
        the work is done on a copy, which is returned.
    redigitize : bool, default True
        Re-compute position-based relationships from labels, even if
        the relationship had previously been digitized.
        NOTE(review): as written, a previously digitized relationship is
        only reverted to labels when its offsets variable is absent from
        the parent dataset -- confirm this is the intended condition.

    Returns
    -------
    DataTree or None
        Only returns a copy if not digitizing in-place.

    Raises
    ------
    ValueError
        If the parent or child node of a label-based relationship has no
        dataset, or if categorical upstream values have missing (negative)
        codes or categories that differ from the downstream values.
    """
    if inplace:
        obj = self
    else:
        obj = self.copy()
    for e in obj._graph.edges:
        r = obj._get_relationship(e)
        if redigitize and r.analog:
            # This relationship has an analog (label-based) name recorded.
            # If the variable it currently points to is missing from the
            # parent dataset, fall back to the analog so that the offsets
            # are computed again below.
            p_dataset = obj._graph.nodes[r.parent_data].get("dataset", None)
            if p_dataset is not None:
                if r.parent_name not in p_dataset:
                    r.indexing = "label"
                    r.parent_name = r.analog
        if r.indexing == "label":
            p_dataset = obj._graph.nodes[r.parent_data].get("dataset", None)
            if p_dataset is None:
                raise ValueError(f"no dataset found for {r.parent_data}")
            c_dataset = obj._graph.nodes[r.child_data].get("dataset", None)
            if c_dataset is None:
                raise ValueError(f"no dataset found for {r.child_data}")
            upstream = p_dataset[r.parent_name]
            downstream = c_dataset[r.child_name]

            upstream_is_categorical = (
                isinstance(upstream, xr.DataArray) and upstream.cat.is_categorical()
            )
            # check if both upstream and downstream are categoricals with the same categories
            if upstream_is_categorical:
                if np.array_equal(upstream.cat.category_array(), downstream):
                    # if so, we can just use the codes
                    offsets = upstream
                    if (offsets < 0).any():
                        raise ValueError(
                            f"detected missing values in digitizing {r.parent_data}.{r.parent_name}"
                        )
                else:
                    raise ValueError(
                        f"upstream ({r.parent_data}.{r.parent_name}) and "
                        f"downstream ({r.child_data}.{r.child_name}) categoricals "
                        f"have different categories"
                    )
            else:
                # vectorize version
                # map each downstream label to its position in the child index
                mapper = {
                    i: j for (j, i) in enumerate(_dataarray_to_numpy(downstream))
                }

                # NOTE(review): labels absent from the child index map to
                # position 0 rather than to a sentinel -- confirm intended.
                def mapper_get(x, mapper=mapper):
                    return mapper.get(x, 0)

                if upstream.size:
                    offsets = xr.apply_ufunc(np.vectorize(mapper_get), upstream)
                else:
                    # nothing to translate; use an empty placeholder array
                    offsets = xr.DataArray([], dims=["index"])
                # a non-integer result is reported as missing values
                if offsets.dtype.kind != "i":
                    warnings.warn(
                        f"detected missing values in digitizing {r.parent_data}.{r.parent_name}",
                        stacklevel=2,
                    )

            # candidate name for write back
            r_parent_name_new = (
                f"{self._BY_OFFSET}{r.parent_name}_{r.child_data}_{r.child_name}"
            )

            # it is common to have mirrored offsets in various dimensions.
            # we'd like to retain only the same data in memory once, so we'll
            # check if these offsets match any existing ones and if so just
            # point to that memory.
            for k in p_dataset:
                if isinstance(k, str) and k.startswith(self._BY_OFFSET):
                    if p_dataset[k].equals(offsets):
                        # we found a match, so we'll assign this name to
                        # the match's memory storage instead of replicating it.
                        obj._graph.nodes[r.parent_data]["dataset"] = (
                            p_dataset.assign({r_parent_name_new: p_dataset[k]})
                        )
                        # r_parent_name_new = k
                        break
            else:
                # no existing offset arrays match, make this new one
                obj._graph.nodes[r.parent_data]["dataset"] = p_dataset.assign(
                    {r_parent_name_new: offsets}
                )
            # repoint the edge at the offsets, remembering the label-based
            # variable name as the analog
            obj._graph.edges[e].update(
                dict(
                    parent_name=r_parent_name_new,
                    indexing="position",
                    analog=r.parent_name,
                )
            )
    if not inplace:
        return obj
@property
def relationships_are_digitized(self):
    """Bool : Whether all relationships are digital (by position)."""
    return all(
        self._get_relationship(e).indexing == "position"
        for e in self._graph.edges
    )


def _arg_tokenizer(
    self, spacename, spacearray, spacearrayname, exclude_dims=None, blends=None
):
    """
    Build index tokens for each dimension of an array in a named node.

    For the root node, every dimension of the array becomes the AST of a
    positional argument name ``_argNN``, where ``NN`` is the position of
    that dimension among the (sorted, filtered) root dimensions.  For any
    other node, each dimension is resolved through an incoming graph edge
    by recursively tokenizing the parent side of the relationship and
    wrapping the result in an index expression on the parent variable.

    Parameters
    ----------
    spacename : str
        Name of the graph node holding the array.
    spacearray : str or array-like
        Name of a variable in that node's dataset, or an object that has
        `dims` (and, for non-root nodes, `attrs`).
    spacearrayname : str
        Passed through unchanged to recursive calls.
    exclude_dims : optional
        Passed to `presorted` when ordering the root dimensions.
    blends : dict, optional
        Returned as the second item of the result; a new empty dict is
        used when not given.

    Returns
    -------
    tokens : tuple
        One item per resolved dimension.  Items are AST nodes, or for
        dimensions not reachable through an edge: a dict mapping index
        labels to positions, False (when `root_node_name` is False), or
        None for dimensions whose name ends in ``_indices``/``_indptr``.
    blends : dict

    Raises
    ------
    ValueError
        If more than one dimension has to be resolved from the node's own
        index instead of through a relationship.
    """
    if blends is None:
        blends = {}

    if spacename == self.root_node_name:
        root_dataset = self.root_dataset
        from .flows import presorted

        root_dims = list(presorted(root_dataset.dims, self.dim_order, exclude_dims))
        if isinstance(spacearray, str):
            from_dims = root_dataset[spacearray].dims
        else:
            from_dims = spacearray.dims
        return (
            tuple(
                ast.parse(f"_arg{root_dims.index(dim):02}", mode="eval").body
                for dim in from_dims
            ),
            blends,
        )

    if isinstance(spacearray, str):
        spacearray_ = self._graph.nodes[spacename]["dataset"][spacearray]
    else:
        spacearray_ = spacearray

    from_dims = spacearray_.dims
    # when the digital encoding names an offset source, the dimensions of
    # that variable are used in place of the array's own dimensions
    offset_source = spacearray_.attrs.get("digital_encoding", {}).get(
        "offset_source", None
    )
    if offset_source is not None:
        from_dims = self._graph.nodes[spacename]["dataset"][offset_source].dims

    tokens = []

    n_missing_tokens = 0
    for dimname in from_dims:
        found_token = False
        for e in self._graph.in_edges(spacename, keys=True):
            this_dim_name = self._graph.edges[e]["child_name"]
            retarget = None
            if dimname != this_dim_name:
                # the edge may still apply if its child name redirects to
                # this dimension
                retarget = self._graph.nodes[spacename][
                    "dataset"
                ].redirection.target(this_dim_name)
                if dimname != retarget:
                    continue
            parent_name = self._graph.edges[e]["parent_name"]
            parent_data = e[0]
            upside_ast, blends_ = self._arg_tokenizer(
                parent_data,
                parent_name,
                spacearrayname=spacearrayname,
                exclude_dims=exclude_dims,
                blends=blends,
            )
            try:
                upside = ", ".join(unparse(t) for t in upside_ast)
            except Exception:
                # Limited to Exception so interrupts are never swallowed by
                # the tolerant branch below.
                if self.root_node_name is False:
                    upside = None
                else:
                    # dump the context of the failure, then re-raise
                    print(f"{parent_data=}")
                    print(f"{parent_name=}")
                    print(f"{spacearrayname=}")
                    print(f"{exclude_dims=}")
                    print(f"{blends=}")
                    for t in upside_ast:
                        str_t = str(t)
                        if len(str_t) < 2000:
                            print(f"t:{str_t}")
                        else:
                            print(f"t:{str_t[:200]}...")
                    raise
            if upside is not None:
                # check for redirection target
                if retarget is not None:
                    tokens.append(
                        f"__{spacename}___digitized_{retarget}_of_{this_dim_name}[__{parent_data}__{parent_name}[{upside}]]"
                    )
                else:
                    tokens.append(f"__{parent_data}__{parent_name}[{upside}]")
                found_token = True
                break
        if not found_token:
            if dimname in self.subspaces[spacename].indexes:
                if self.root_node_name is False:
                    tokens.append(False)
                else:
                    # fall back to a label -> position lookup on the index
                    ix = self.subspaces[spacename].indexes[dimname]
                    ix = {i: n for n, i in enumerate(ix)}
                    tokens.append(ix)
                n_missing_tokens += 1
            elif dimname.endswith("_indices") or dimname.endswith("_indptr"):
                tokens.append(None)  # this dimension corresponds to a blender

    if n_missing_tokens > 1:
        raise ValueError("at most one missing dimension is allowed")
    result = []
    for t in tokens:
        if isinstance(t, str):
            # print(f"TOKENIZE: {spacename=} {spacearray=} {t}")
            result.append(ast.parse(t, mode="eval").body)
        else:
            result.append(t)
    return tuple(result), blends


@property
def coords(self):
    """The coordinates of the root dataset."""
    return self.root_dataset.coords


def get_index(self, dim):
    """
    Find the index for a dimension in the first subspace that has it.

    Parameters
    ----------
    dim : str

    Returns
    -------
    index or None
        The entry for `dim` from the `indexes` of the first subspace whose
        coordinates include `dim`, or None if no subspace has it.
    """
    for _spacename, subspace in self.subspaces.items():
        if dim in subspace.coords:
            return subspace.indexes[dim]
    return None


def copy(self):
    """Create a new tree of the same type from a copy of this tree's graph."""
    return type(self)(
        self._graph.copy(), self.root_node_name, **self.__shallow_copy_extras()
    )


def all_var_names(self, uniquify=False, _duplicated_names=None):
    """
    List the names of all variables across all subspaces.

    Parameters
    ----------
    uniquify : bool, default False
        If True, names that appear in more than one subspace are reported
        as ``<spacename>.<name>`` for every occurrence.  If False, a
        repeated name raises an error.
    _duplicated_names : set, optional
        For internal use on the second pass of a uniquifying call; the
        names already known to be duplicated.

    Returns
    -------
    list of str

    Raises
    ------
    ValueError
        If a name is duplicated and `uniquify` is False.
    """
    ordered_names = []
    require_unique = _duplicated_names is None and not uniquify
    # Duplicates are only known after a full scan, so a uniquifying call
    # scans once to find them and a second time to emit qualified names.
    need_second_pass = _duplicated_names is None and uniquify
    discovered_names = set()
    duplicated_names = _duplicated_names or set()
    for spacename, space in self.subspaces_iter():
        for name in space.variables:
            if name in duplicated_names:
                if require_unique:
                    raise ValueError(f"duplicate name {name}")
                elif uniquify:
                    ordered_names.append(f"{spacename}.{name}")
                else:
                    ordered_names.append(name)
            elif name in discovered_names:
                duplicated_names.add(name)
                if require_unique:
                    raise ValueError(f"duplicate name {name}")
                else:
                    ordered_names.append(name)
            else:
                discovered_names.add(name)
                ordered_names.append(name)
    if need_second_pass:
        return self.all_var_names(uniquify=True, _duplicated_names=duplicated_names)
    return ordered_names


def merged_dataset(self, columns=None, uniquify=False):
    """
    Merge variables from across the tree into one dataset on the root dim.

    Parameters
    ----------
    columns : Collection[str], optional
        Names of the variables to include.  Defaults to the result of
        `all_var_names`.
    uniquify : bool, default False
        Passed to `all_var_names` when `columns` is not given.

    Returns
    -------
    xarray.Dataset
        Variables whose name matches their dimension name are attached as
        coordinates; all others are merged as data variables.

    Raises
    ------
    NotImplementedError
        If the root dataset has more than one dimension.
    """
    if columns is None:
        columns = self.all_var_names(uniquify=uniquify)
    if len(self.root_dataset.dims) > 1:
        raise NotImplementedError("only single dim root datasets")
    dim_name = self.root_dataset.single_dim.dim_name
    vx = []
    coords = {}
    for k in columns:
        v = self._getitem(k).single_dim.rename(dim_name)
        if v.name == v.dims[0]:
            coords[v.name] = v
        else:
            vx.append(v)
    result = xr.merge(vx, compat="override", join="override")
    if coords:
        # assign_coords returns a new object, so the result must be kept
        result = result.assign_coords(coords)
    return result