# Source code for sharrow.flows

import ast
import base64
import hashlib
import importlib
import inspect
import io
import logging
import os
import re
import sys
import textwrap
import time
import warnings

import numba as nb
import numpy as np
import pandas as pd
import xarray as xr

from ._infer_version import __version__
from .aster import expression_for_numba, extract_all_name_tokens, extract_names_2
from .filewrite import blacken, rewrite
from .relationships import DataTree
from .table import Table

logger = logging.getLogger("sharrow")


class CacheMissWarning(UserWarning):
    """
    Warning category derived from :class:`UserWarning`.

    NOTE(review): presumably issued when a cached artifact is not found;
    the code that emits it is outside this part of the file -- confirm.
    """


# Names that are never treated as data variables: `filter_name_tokens`
# subtracts this set from the name tokens it extracts from an expression.
# The set is extended further down in this module with internal argument
# names.
well_known_names = {
    "nb",
    "np",
    "pd",
    "xr",
    "pa",
    "log",
    "exp",
    "log1p",
    "expm1",
    "max",
    "min",
    "piece",
    "hard_sigmoid",
    "transpose_leading",
    "clip",
    "get",
}


def one_based(n):
    """Return a ``pd.RangeIndex`` running from 1 through ``n`` inclusive."""
    stop = n + 1
    return pd.RangeIndex(start=1, stop=stop)


def zero_based(n):
    """Return a ``pd.RangeIndex`` running from 0 through ``n - 1`` inclusive."""
    return pd.RangeIndex(start=0, stop=n)


def clean(s):
    """
    Turn an arbitrary value into a string usable as a python identifier.

    Non-word characters are replaced by underscores and a leading digit is
    prefixed with an underscore.  When that substitution changed anything,
    or when the result exceeds 120 characters, the result is cut to its
    first 90 characters and a 24 character base32 hash of the original
    text is appended, so that distinct inputs keep distinct cleaned names.

    Parameters
    ----------
    s : str
        The text to clean.  A non-string is first rendered as
        ``f"{type(s)}-{s}"``.

    Returns
    -------
    cleaned : str
    """
    text = s if isinstance(s, str) else f"{type(s)}-{s}"
    identifier = re.sub(r"\W|^(?=\d)", "_", text)
    if identifier == text and len(identifier) <= 120:
        # already a short, untouched identifier: no hash suffix needed
        return identifier
    # a 15 byte digest encodes to exactly 24 base32 characters (no padding)
    digest = hashlib.blake2b(text.encode(), digest_size=15).digest()
    suffix = base64.b32encode(digest).decode()
    return f"{identifier[:90]}_{suffix}"


def presorted(sortable, presort=None, exclude=None):
    """
    Iterate over a collection in sorted order, with chosen items up front.

    Parameters
    ----------
    sortable : Collection
        The elements to emit.  Duplicates are collapsed, as the elements
        are gathered into a set.
    presort : Iterable, optional
        Elements to emit first, in the order given here, provided they are
        members of `sortable`.
    exclude : Container, optional
        Elements that are never emitted, whether they come from `presort`
        or from the sorted remainder.

    Yields
    ------
    Any
        The members of `sortable` that are not in `exclude`.
    """
    remaining = set(sortable)

    def allowed(item):
        return exclude is None or item not in exclude

    for item in presort if presort is not None else ():
        if item not in remaining:
            continue
        # drop it from the remainder so it is not emitted a second time
        remaining.remove(item)
        if allowed(item):
            yield item
    yield from filter(allowed, sorted(remaining))


def _flip_flop_def(v):
    """
    Swap a definition for the text following its ``# sharrow:`` marker.

    When `v` is a string containing ``# sharrow:``, everything after the
    first occurrence of that marker is returned, stripped of surrounding
    whitespace.  Any other value is returned unchanged.
    """
    marker = "# sharrow:"
    if not isinstance(v, str) or marker not in v:
        return v
    _, _, tail = v.partition(marker)
    return tail.strip()


# Internal container names are also treated as well-known, so they are
# removed from the name tokens by `filter_name_tokens`.
well_known_names |= {
    "_args",
    "_inputs",
    "_outputs",
}
# One hundred zero-padded argument names, "_arg00" through "_arg99".
# `filter_name_tokens` reports which of these appear in an expression.
ARG_NAMES = {f"_arg{n:02}" for n in range(100)}
well_known_names |= ARG_NAMES


def filter_name_tokens(expr, matchable_names=None):
    """
    Split the name tokens of an expression into data names and arg names.

    Parameters
    ----------
    expr
        Passed as-is to `extract_all_name_tokens`.
    matchable_names : set, optional
        When given and non-empty, only tokens that are members of this set
        are kept among the data names.

    Returns
    -------
    name_tokens : set
        The extracted tokens minus everything in `well_known_names`,
        optionally intersected with `matchable_names`.
    arg_tokens : set
        The extracted tokens that are members of `ARG_NAMES`.
    """
    tokens = extract_all_name_tokens(expr)
    # take the arg names first: they are part of well_known_names and so
    # vanish from `tokens` in the next step
    found_args = tokens & ARG_NAMES
    tokens -= well_known_names
    if matchable_names:
        tokens &= matchable_names
    return tokens, found_args


class ExtractOptionalGetTokens(ast.NodeVisitor):
    """
    Collect ``name.get(<constant>, ...)`` calls found in python expressions.

    Only calls whose receiver is a bare name listed in `from_names` and
    whose first positional argument is a constant are recorded, as
    ``(receiver_name, constant_value)`` tuples.  A call with no fallback
    goes into `required_get_tokens`; a call with a second positional
    argument or a single ``default=`` keyword goes into
    `optional_get_tokens`.

    Parameters
    ----------
    from_names : Container[str]
        Receiver names whose ``.get`` calls are to be inspected.
    """

    def __init__(self, from_names):
        self.optional_get_tokens = set()
        self.required_get_tokens = set()
        self.from_names = from_names

    def _classify_get(self, node):
        """Record `node` if it is a ``.get`` call on one of `from_names`."""
        func = node.func
        if not (isinstance(func, ast.Attribute) and func.attr == "get"):
            return
        receiver = func.value
        if not (isinstance(receiver, ast.Name) and receiver.id in self.from_names):
            return
        n_positional = len(node.args)
        if n_positional > 2:
            raise ValueError(
                f"{node.func.value.id}.get with more than 2 positional arguments"
            )
        if n_positional == 0 or not isinstance(node.args[0], ast.Constant):
            return
        token = (receiver.id, node.args[0].value)
        if n_positional == 2:
            # second positional argument is the fallback value
            self.optional_get_tokens.add(token)
            return
        n_keywords = len(node.keywords)
        if n_keywords == 0:
            self.required_get_tokens.add(token)
        elif n_keywords == 1 and node.keywords[0].arg == "default":
            self.optional_get_tokens.add(token)
        else:
            raise ValueError(
                f"{node.func.value.id}.get with unexpected keyword arguments"
            )

    def visit_Call(self, node):
        """Inspect one call node, then descend into its children."""
        self._classify_get(node)
        self.generic_visit(node)

    def check(self, node):
        """
        Scan a string, an AST node, or any nesting of iterables of these.

        Strings are parsed with `ast.parse`; objects that are neither AST
        nodes nor iterable are ignored.

        Returns
        -------
        set
            The accumulated `optional_get_tokens`.
        """
        if isinstance(node, str):
            node = ast.parse(node)
        if isinstance(node, ast.AST):
            self.visit(node)
            return self.optional_get_tokens
        try:
            members = iter(node)
        except TypeError:
            return self.optional_get_tokens
        for member in members:
            self.check(member)
        return self.optional_get_tokens


def coerce_to_range_index(idx):
    """
    Convert an index of consecutive integers into a ``pd.RangeIndex``.

    The conversion happens only when `idx` is a numeric (signed, unsigned
    or floating point) ``pd.Index`` whose values are exactly
    ``start, start + 1, ..., start + size - 1``.  Anything else, including
    an empty index, an index with repeated or fractional values, or an
    object that is not a ``pd.Index``, is returned unchanged.

    The type test uses the index dtype instead of the ``Int64Index``,
    ``Float64Index`` and ``UInt64Index`` classes, which no longer exist in
    pandas 2.0 and later.

    Parameters
    ----------
    idx : pd.Index or Any

    Returns
    -------
    pd.RangeIndex or the original `idx`
    """
    if isinstance(idx, pd.RangeIndex):
        return idx
    if not isinstance(idx, pd.Index):
        return idx
    dtype = idx.dtype
    # extension dtypes (nullable ints, categoricals, ...) are left alone
    if not isinstance(dtype, np.dtype) or dtype.kind not in "iuf":
        return idx
    values = idx.to_numpy()
    if values.size == 0:
        return idx
    first = values[0]
    if not np.isfinite(first):
        return idx
    start = int(first)
    stop = start + values.size
    bounds = np.iinfo(np.int64)
    if start < bounds.min or stop > bounds.max:
        return idx
    # An exact element-wise comparison rejects duplicates (e.g. [1, 1, 3])
    # and fractional values (e.g. [1.0, 1.5, 3.0]), which a test based only
    # on monotonicity and the first/last difference would let through.
    if np.array_equal(values, np.arange(start, stop)):
        return pd.RangeIndex(start, stop)
    return idx


# Source template for one numba-jitted function that returns the value of a
# single expression.  Placeholders: `init_expr` (emitted as a comment line
# above the function), `error_model`, `boundscheck`, `nopython`, `fastmath`,
# `fname`, `argtokens`, `nametokens` and `expr`.  The `nogil` option is
# filled from the `nopython` placeholder, and caching is switched off.
# NOTE(review): presumably rendered with str.format by code outside this
# part of the file -- confirm at the call site.
FUNCTION_TEMPLATE = """

# {init_expr}
@nb.jit(
    cache=False,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def {fname}(
    {argtokens}
    _outputs,
    {nametokens}
):
    return {expr}

"""


# Source template for `irunner` over one dimension.  The generated function
# allocates a result of shape (argshape[0], len_self_raw_functions) and, in
# an `nb.prange` loop, calls `linemaker` to fill each row.  When a `mask` is
# given, rows whose mask entry is false are set to NaN and `linemaker` is
# skipped for them.
# NOTE(review): presumably rendered with str.format by code outside this
# part of the file -- confirm at the call site.
IRUNNER_1D_TEMPLATE = """
@nb.jit(
    cache=True,
    parallel=True,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def irunner(
    argshape,
    {joined_namespace_names}
    dtype=np.{dtype},
    mask=None,
):
    result = np.empty((argshape[0], {len_self_raw_functions}), dtype=dtype)
    if mask is not None:
        assert mask.ndim == 1
        assert mask.shape[0] == argshape[0]
    for j0 in nb.prange(argshape[0]):
        if mask is not None:
            if not mask[j0]:
                result[j0, :] = np.nan
                continue
        linemaker(result[j0], j0, {joined_namespace_names})
    return result
"""

# Source template for `irunner` over two dimensions.  The generated function
# allocates a result of shape
# (argshape[0], argshape[1], len_self_raw_functions) and calls `linemaker`
# for every (j0, j1) cell, with the outer loop run under `nb.prange`.  When
# a `mask` is given, cells whose mask entry is false are set to NaN and
# `linemaker` is skipped for them via `continue`, matching the 1D template;
# without that `continue`, `linemaker` would still be called on the cell
# that was just set to NaN.
# NOTE(review): presumably rendered with str.format by code outside this
# part of the file -- confirm at the call site.
IRUNNER_2D_TEMPLATE = """
@nb.jit(
    cache=True,
    parallel=True,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def irunner(
    argshape,
    {joined_namespace_names}
    dtype=np.{dtype},
    mask=None,
):
    result = np.empty((argshape[0], argshape[1], {len_self_raw_functions}), dtype=dtype)
    if mask is not None:
        assert mask.ndim == 2
        assert mask.shape[0] == argshape[0]
        assert mask.shape[1] == argshape[1]
    for j0 in nb.prange(argshape[0]):
        for j1 in range(argshape[1]):
            if mask is not None:
                if not mask[j0, j1]:
                    result[j0, j1, :] = np.nan
                    continue
            linemaker(result[j0, j1], j0, j1, {joined_namespace_names})
    return result
"""

# Source template for `idotter` over one dimension.  For every j0 the code
# substituted for `meta_code_stack_dot` runs, and then the dot product of
# `intermediate` with `dotarray` is written into result[j0, :].  With more
# than 1000 rows the loop uses `nb.prange` and allocates a fresh
# `intermediate` per iteration; otherwise a plain `range` loop shares a
# single `intermediate` allocated once before the loop.
# NOTE(review): `meta_code_stack_dot` presumably fills `intermediate`; the
# code that builds it is outside this part of the file -- confirm.
IDOTTER_1D_TEMPLATE = """
@nb.jit(
    cache=True,
    parallel=True,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def idotter(
    argshape,
    {joined_namespace_names}
    dtype=np.{dtype},
    dotarray=None,
):
    if dotarray is None:
        raise ValueError("dotarray cannot be None")
    assert dotarray.ndim == 2
    result = np.empty((argshape[0], dotarray.shape[1]), dtype=dtype)
    if argshape[0] > 1000:
        for j0 in nb.prange(argshape[0]):
            intermediate = np.zeros({len_self_raw_functions}, dtype=dtype)
            {meta_code_stack_dot}
            np.dot(intermediate, dotarray, out=result[j0,:])
    else:
        intermediate = np.zeros({len_self_raw_functions}, dtype=dtype)
        for j0 in range(argshape[0]):
            {meta_code_stack_dot}
            np.dot(intermediate, dotarray, out=result[j0,:])
    return result
"""

# Source template for `idotter` over two dimensions.  For every (j0, j1)
# cell the code substituted for `meta_code_stack_dot` runs, and then the dot
# product of `intermediate` with `dotarray` is written into
# result[j0, j1, :].  With more than 1000 entries on the first axis the
# outer loop uses `nb.prange` and a fresh `intermediate` is allocated per
# cell; otherwise plain `range` loops share one `intermediate`.
# NOTE(review): `meta_code_stack_dot` presumably fills `intermediate`; the
# code that builds it is outside this part of the file -- confirm.
IDOTTER_2D_TEMPLATE = """
@nb.jit(
    cache=True,
    parallel=True,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def idotter(
    argshape,
    {joined_namespace_names}
    dtype=np.{dtype},
    dotarray=None,
):
    if dotarray is None:
        raise ValueError("dotarray cannot be None")
    assert dotarray.ndim == 2
    result = np.empty((argshape[0], argshape[1], dotarray.shape[1]), dtype=dtype)
    if argshape[0] > 1000:
        for j0 in nb.prange(argshape[0]):
          for j1 in range(argshape[1]):
            intermediate = np.zeros({len_self_raw_functions}, dtype=dtype)
            {meta_code_stack_dot}
            np.dot(intermediate, dotarray, out=result[j0,j1,:])
    else:
        intermediate = np.zeros({len_self_raw_functions}, dtype=dtype)
        for j0 in range(argshape[0]):
          for j1 in range(argshape[1]):
            {meta_code_stack_dot}
            np.dot(intermediate, dotarray, out=result[j0,j1,:])
    return result
"""

# Source template for `linemaker(intermediate, j0, ...)`, the per-row helper
# that the one-dimensional `irunner` template calls.  Its whole body is the
# code substituted for `meta_code_stack_dot`; caching is switched off.
# NOTE(review): the body placeholder sits at a 12-space indent, so the
# substituted code presumably has to match that -- confirm where it is
# generated.
ILINER_1D_TEMPLATE = """
@nb.jit(
    cache=False,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def linemaker(
    intermediate, j0,
    {joined_namespace_names}
):
            {meta_code_stack_dot}

"""

# Source template for `linemaker(intermediate, j0, j1, ...)`, the per-cell
# helper that the two-dimensional `irunner` template calls.  Its whole body
# is the code substituted for `meta_code_stack_dot`; caching is switched off.
# NOTE(review): the body placeholder sits at a 12-space indent, so the
# substituted code presumably has to match that -- confirm where it is
# generated.
ILINER_2D_TEMPLATE = """
@nb.jit(
    cache=False,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def linemaker(
    intermediate, j0, j1,
    {joined_namespace_names}
):
            {meta_code_stack_dot}

"""


# Source template shared by the MNL templates below (it is prepended to
# both MNL_1D_TEMPLATE and MNL_2D_TEMPLATE).  It defines two jitted helpers
# that sample alternatives from a probability array using sorted random
# points and a running cumulative probability:
#
# - `_sample_choices_maker` writes one chosen alternative and its
#   probability per random draw.
# - `_sample_choices_maker_counted` writes each distinct chosen alternative
#   once, in consecutive output slots, together with the number of draws
#   that landed on it in `out_pick_count`.
#
# Both fall back to the last alternative with probability of at least 1e-30
# for any draws left over after the cumulative sum is exhausted.
# NOTE(review): `unique_s` is assigned but never used in
# `_sample_choices_maker`.
MNL_GENERIC_TEMPLATE = """
@nb.jit(
    cache=True,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def _sample_choices_maker(
        prob_array,
        random_array,
        out_choices,
        out_choice_probs,
):
    '''
    Random sample of alternatives.

    Parameters
    ----------
    prob_array : array of float, shape (n_alts)
    random_array : array of float, shape (n_samples)
    out_choices : array of int, shape (n_samples) output
    out_choice_probs : array of float, shape (n_samples) output

    '''
    sample_size = random_array.size
    n_alts = prob_array.size

    random_points = np.sort(random_array)
    a = 0
    s = 0
    unique_s = 0
    z = 0.0
    for a in range(n_alts):
        z += prob_array[a]
        while s < sample_size and z > random_points[s]:
            out_choices[s] = a
            out_choice_probs[s] = prob_array[a]
            s += 1
        if s >= sample_size:
            break
    if s < sample_size:
        # rare condition, only if a random point is greater than 1 (a bug)
        # or if the sum of probabilities is less than 1 and a random point
        # is greater than that sum, which due to the limits of numerical
        # precision can technically happen
        a = n_alts-1
        while prob_array[a] < 1e-30 and a > 0:
            # slip back to the last choice with non-trivial prob
            a -= 1
        while s < sample_size:
            out_choices[s] = a
            out_choice_probs[s] = prob_array[a]
            s += 1



@nb.jit(
    cache=True,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def _sample_choices_maker_counted(
        prob_array,
        random_array,
        out_choices,
        out_choice_probs,
        out_pick_count,
):
    '''
    Random sample of alternatives.

    Parameters
    ----------
    prob_array : array of float, shape (n_alts)
    random_array : array of float, shape (n_samples)
    out_choices : array of int, shape (n_samples) output
    out_choice_probs : array of float, shape (n_samples) output
    out_pick_count : array of int, shape (n_samples) output

    '''
    sample_size = random_array.size
    n_alts = prob_array.size

    random_points = np.sort(random_array)
    a = 0
    s = 0
    unique_s = -1
    z = 0.0
    out_pick_count[:] = 0
    for a in range(n_alts):
        z += prob_array[a]
        if s < sample_size and z > random_points[s]:
            unique_s += 1
        while s < sample_size and z > random_points[s]:
            out_choices[unique_s] = a
            out_choice_probs[unique_s] = prob_array[a]
            out_pick_count[unique_s] += 1
            s += 1
        if s >= sample_size:
            break
    if s < sample_size:
        # rare condition, only if a random point is greater than 1 (a bug)
        # or if the sum of probabilities is less than 1 and a random point
        # is greater than that sum, which due to the limits of numerical
        # precision can technically happen
        a = n_alts-1
        while prob_array[a] < 1e-30 and a > 0:
            # slip back to the last choice with non-trivial prob
            a -= 1
        if out_choices[unique_s] != a:
            unique_s += 1
        while s < sample_size:
            out_choices[unique_s] = a
            out_choice_probs[unique_s] = prob_array[a]
            out_pick_count[unique_s] += 1
            s += 1

"""

# Source template for the one-dimensional MNL transform: the shared sampler
# helpers from MNL_GENERIC_TEMPLATE, a module-level `logit_ndims = 1`, and
# `mnl_transform_plus1d`.  For each unmasked row j0 the generated function
# runs the `meta_code_stack_dot` code, multiplies `intermediate` by
# `dotarray`, turns the result into probabilities with a max-shifted
# softmax, optionally stores the logsum, and samples choices using
# `random_draws[j0]`.  It returns (result, result_p, pick_count, _logsums).
# NOTE(review): unlike the 2D template, `random_draws` is not checked for
# None before `random_draws.shape[1]` is read -- confirm callers always
# supply it.
MNL_1D_TEMPLATE = (
    MNL_GENERIC_TEMPLATE
    + """

logit_ndims = 1

@nb.jit(
    cache=True,
    parallel=True,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def mnl_transform_plus1d(
    argshape,
    {joined_namespace_names}
    dtype=np.{dtype},
    dotarray=None,
    random_draws=None,
    pick_counted=False,
    logsums=False,
    choice_dtype=np.int32,
    pick_count_dtype=np.int32,
    mask=None,
):
    if dotarray is None:
        raise ValueError("dotarray cannot be None")
    assert dotarray.ndim == 2
    if mask is not None:
        assert mask.ndim == 1
        assert mask.shape[0] == argshape[0]
    result = np.full((argshape[0], random_draws.shape[1]), -1, dtype=choice_dtype)
    result_p = np.zeros((argshape[0], random_draws.shape[1]), dtype=dtype)
    if pick_counted:
        pick_count = np.zeros((argshape[0], random_draws.shape[1]), dtype=pick_count_dtype)
    else:
        pick_count = np.zeros((argshape[0], 0), dtype=pick_count_dtype)
    if logsums:
        _logsums = np.zeros((argshape[0], ), dtype=dtype)
    else:
        _logsums = np.zeros((0, ), dtype=dtype)
    for j0 in nb.prange(argshape[0]):
            if mask is not None:
                if not mask[j0]:
                    continue
            intermediate = np.zeros({len_self_raw_functions}, dtype=dtype)
            {meta_code_stack_dot}
            dotprod = np.dot(intermediate, dotarray)
            shifter = np.max(dotprod)
            partial = np.exp(dotprod - shifter)
            local_sum = np.sum(partial)
            partial /= local_sum
            if logsums:
                _logsums[j0] = np.log(local_sum) + shifter
            if pick_counted:
                _sample_choices_maker_counted(partial, random_draws[j0], result[j0], result_p[j0], pick_count[j0])
            else:
                _sample_choices_maker(partial, random_draws[j0], result[j0], result_p[j0])
    return result, result_p, pick_count, _logsums

"""
)
# @nb.jit(
#     cache=True,
#     parallel=True,
#     error_model='{error_model}',
#     boundscheck={boundscheck},
#     nopython={nopython},
#     fastmath={fastmath})
# def mnl_transform_plus1d(
#     argshape,
#     {joined_namespace_names}
#     dtype=np.{dtype},
#     dotarray=None,
#     random_draws=None,
#     pick_counted=False,
#     logsums=False,
#     choice_dtype=np.int32,
#     pick_count_dtype=np.int32,
# ):
#     if dotarray is None:
#         raise ValueError("dotarray cannot be None")
#     assert dotarray.ndim == 2
#     result = np.full((argshape[0], argshape[1], random_draws.shape[1]), -1, dtype=choice_dtype)
#     result_p = np.zeros((argshape[0], argshape[1], random_draws.shape[1]), dtype=dtype)
#     if pick_counted:
#         pick_count = np.zeros((argshape[0], argshape[1], random_draws.shape[1]), dtype=pick_count_dtype)
#     else:
#         pick_count = np.zeros((argshape[0], argshape[1], 0), dtype=pick_count_dtype)
#     if logsums:
#         _logsums = np.zeros((argshape[0], argshape[1], ), dtype=dtype)
#     else:
#         _logsums = np.zeros((0, 0), dtype=dtype)
#     for j0 in nb.prange(argshape[0]):
#         for k0 in range(argshape[1]):
#             intermediate = np.zeros({len_self_raw_functions}, dtype=dtype)
#             {meta_code_stack_dot}
#             dotprod = np.dot(intermediate, dotarray)
#             shifter = np.max(dotprod)
#             partial = np.exp(dotprod - shifter)
#             local_sum = np.sum(partial)
#             partial /= local_sum
#             if logsums:
#                 _logsums[j0,k0] = np.log(local_sum) + shifter
#             if pick_counted:
#                 _sample_choices_maker_counted(
#                   partial, random_draws[j0,k0], result[j0,k0], result_p[j0,k0], pick_count[j0,k0]
#                 )
#             else:
#                 _sample_choices_maker(partial, random_draws[j0,k0], result[j0,k0], result_p[j0,k0])
#     return result, result_p, pick_count, _logsums


# Source template for the two-dimensional MNL transforms: the shared sampler
# helpers from MNL_GENERIC_TEMPLATE, a module-level `logit_ndims = 2`, and
# two generated functions.
#
# - `mnl_transform` treats the second axis as the alternatives.  It requires
#   `dotarray` to have exactly one column, computes one utility per j1,
#   applies a max-shifted softmax across j1, and samples with
#   `random_draws[j0]`.  The mask is one-dimensional, over j0.
# - `mnl_transform_plus1d` computes a set of probabilities for every
#   (j0, j1) cell from `np.dot(intermediate, dotarray)` and samples with
#   `random_draws[j0, j1]`.  The mask is two-dimensional.
#
# In both functions, when `logsums == 1` only the logsum is stored and the
# sampling step is skipped for that row or cell.  Both return
# (result, result_p, pick_count, _logsums).
MNL_2D_TEMPLATE = (
    MNL_GENERIC_TEMPLATE
    + """

logit_ndims = 2

@nb.jit(
    cache=True,
    parallel=True,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def mnl_transform(
    argshape,
    {joined_namespace_names}
    dtype=np.{dtype},
    dotarray=None,
    random_draws=None,
    pick_counted=False,
    logsums=False,
    choice_dtype=np.int32,
    pick_count_dtype=np.int32,
    mask=None,
):
    if dotarray is None:
        raise ValueError("dotarray cannot be None")
    assert dotarray.ndim == 2
    assert dotarray.shape[1] == 1
    dotarray = dotarray.reshape(-1)
    if random_draws is None:
        raise ValueError("random_draws cannot be None")
    assert random_draws.ndim == 2
    assert random_draws.shape[0] == argshape[0]
    if mask is not None:
        assert mask.ndim == 1
        assert mask.shape[0] == argshape[0]

    result = np.full((argshape[0], random_draws.shape[1]), -1, dtype=choice_dtype)
    result_p = np.zeros((argshape[0], random_draws.shape[1]), dtype=dtype)
    if pick_counted:
        pick_count = np.zeros((argshape[0], random_draws.shape[1]), dtype=pick_count_dtype)
    else:
        pick_count = np.zeros((argshape[0], 0), dtype=pick_count_dtype)
    if logsums:
        _logsums = np.zeros((argshape[0], ), dtype=dtype)
    else:
        _logsums = np.zeros((0, ), dtype=dtype)
    for j0 in nb.prange(argshape[0]):
        if mask is not None:
            if not mask[j0]:
                continue
        partial = np.zeros(argshape[1], dtype=dtype)
        intermediate = np.zeros({len_self_raw_functions}, dtype=dtype)
        shifter = -99999
        for j1 in range(argshape[1]):
            intermediate[:] = 0
            {meta_code_stack_dot}
            v = partial[j1] = np.dot(intermediate, dotarray)
            if v > shifter:
                shifter = v
        for j1 in range(argshape[1]):
            partial[j1] = np.exp(partial[j1] - shifter)
        local_sum = np.sum(partial)
        if logsums:
            _logsums[j0] = np.log(local_sum) + shifter
            if logsums == 1:
              continue
        partial /= local_sum
        if pick_counted:
            _sample_choices_maker_counted(partial, random_draws[j0], result[j0], result_p[j0], pick_count[j0])
        else:
            _sample_choices_maker(partial, random_draws[j0], result[j0], result_p[j0])
    return result, result_p, pick_count, _logsums


@nb.jit(
    cache=True,
    parallel=True,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def mnl_transform_plus1d(
    argshape,
    {joined_namespace_names}
    dtype=np.{dtype},
    dotarray=None,
    random_draws=None,
    pick_counted=False,
    logsums=False,
    choice_dtype=np.int32,
    pick_count_dtype=np.int32,
    mask=None,
):
    if dotarray is None:
        raise ValueError("dotarray cannot be None")
    assert dotarray.ndim == 2
    assert dotarray.shape[1] >= 1
    if random_draws is None:
        raise ValueError("random_draws cannot be None")
    assert random_draws.ndim == 3
    assert random_draws.shape[0] == argshape[0]
    assert random_draws.shape[1] == argshape[1]
    if mask is not None:
        assert mask.ndim == 2
        assert mask.shape[0] == argshape[0]
        assert mask.shape[1] == argshape[1]

    result = np.full((argshape[0], argshape[1], random_draws.shape[2]), -1, dtype=choice_dtype)
    result_p = np.zeros((argshape[0], argshape[1], random_draws.shape[2]), dtype=dtype)
    if pick_counted:
        pick_count = np.zeros((argshape[0], argshape[1], random_draws.shape[2]), dtype=pick_count_dtype)
    else:
        pick_count = np.zeros((argshape[0], argshape[1], 0), dtype=pick_count_dtype)
    if logsums:
        _logsums = np.zeros((argshape[0], argshape[1], ), dtype=dtype)
    else:
        _logsums = np.zeros((0, 0), dtype=dtype)
    for j0 in nb.prange(argshape[0]):
        partial = np.zeros(dotarray.shape[1], dtype=dtype)
        for j1 in range(argshape[1]):
            if mask is not None:
                if not mask[j0,j1]:
                    continue
            intermediate = np.zeros({len_self_raw_functions}, dtype=dtype)
            {meta_code_stack_dot}
            partial = np.dot(intermediate, dotarray, out=partial)
            shifter = np.max(partial)
            partial = np.exp(partial - shifter)
            local_sum = np.sum(partial)
            if logsums:
                _logsums[j0,j1] = np.log(local_sum) + shifter
                if logsums == 1:
                    continue
            partial /= local_sum
            if pick_counted:
                _sample_choices_maker_counted(
                    partial, random_draws[j0,j1], result[j0,j1], result_p[j0,j1], pick_count[j0,j1]
                )
            else:
                _sample_choices_maker(partial, random_draws[j0,j1], result[j0,j1], result_p[j0,j1])
    return result, result_p, pick_count, _logsums

"""
)

# Source template for the one-dimensional nested logit transform.  The
# generated `nl_transform` computes, for each unmasked row j0, the
# alternative utilities as `np.dot(intermediate, dotarray)` placed in the
# first `n_alts` slots of a length-`n_nodes` utility array, and passes them
# with the nesting arrays to `_utility_to_probability`, imported from
# `sharrow.nested_logit`.  The last utility entry is stored as the logsum
# when `logsums` is truthy.  When `logsums == 1` the choice outputs are
# allocated with zero rows and no sampling is done.  It returns
# (result, result_p, pick_count, _logsums).
# NOTE(review): this template calls `_sample_choices_maker` and
# `_sample_choices_maker_counted` but does not define them; presumably it is
# written into the same generated module as one of the MNL templates --
# confirm where it is rendered.
# NOTE(review): `random_draws` is not checked for None before
# `random_draws.shape[1]` is read.
NL_1D_TEMPLATE = """

from sharrow.nested_logit import _utility_to_probability

@nb.jit(
    cache=True,
    parallel=True,
    error_model='{error_model}',
    boundscheck={boundscheck},
    nopython={nopython},
    fastmath={fastmath},
    nogil={nopython})
def nl_transform(
    argshape,
    {joined_namespace_names}
    dtype=np.{dtype},
    dotarray=None,
    random_draws=None,
    pick_counted=False,
    logsums=False,
    n_nodes=0,
    n_alts=0,
    edges_up=None,  # int input shape=[edges]
    edges_dn=None,  # int input shape=[edges]
    mu_params=None,  # float input shape=[nests]
    start_slots=None,  # int input shape=[nests]
    len_slots=None,  # int input shape=[nests]
    choice_dtype=np.int32,
    pick_count_dtype=np.int32,
    mask=None,
):
    if dotarray is None:
        raise ValueError("dotarray cannot be None")
    assert dotarray.ndim == 2
    if mask is not None:
        assert mask.ndim == 1
        assert mask.shape[0] == argshape[0]
    if logsums == 1:
        result = np.full((0, random_draws.shape[1]), -1, dtype=choice_dtype)
        result_p = np.zeros((0, random_draws.shape[1]), dtype=dtype)
    else:
        result = np.full((argshape[0], random_draws.shape[1]), -1, dtype=choice_dtype)
        result_p = np.zeros((argshape[0], random_draws.shape[1]), dtype=dtype)
    if pick_counted:
        pick_count = np.zeros((argshape[0], random_draws.shape[1]), dtype=pick_count_dtype)
    else:
        pick_count = np.zeros((argshape[0], 0), dtype=pick_count_dtype)
    if logsums:
        _logsums = np.zeros((argshape[0], ), dtype=dtype)
    else:
        _logsums = np.zeros((0, ), dtype=dtype)
    for j0 in nb.prange(argshape[0]):
            if mask is not None:
                if not mask[j0]:
                    continue
            intermediate = np.zeros({len_self_raw_functions}, dtype=dtype)
            {meta_code_stack_dot}
            utility = np.zeros(n_nodes, dtype=dtype)
            utility[:n_alts] = np.dot(intermediate, dotarray)
            if logsums == 1:
                logprob = np.zeros(0, dtype=dtype)
                probability = np.zeros(0, dtype=dtype)
            else:
                logprob = np.zeros(n_nodes, dtype=dtype)
                probability = np.zeros(n_nodes, dtype=dtype)
            _utility_to_probability(
                n_alts,
                edges_up,  # int input shape=[edges]
                edges_dn,  # int input shape=[edges]
                mu_params,  # float input shape=[nests]
                start_slots,  # int input shape=[nests]
                len_slots,  # int input shape=[nests]
                (logsums==1),
                utility,  # float output shape=[nodes]
                logprob,  # float output shape=[nodes]
                probability,  # float output shape=[nodes]
            )
            if logsums:
                _logsums[j0] = utility[-1]
            if logsums != 1:
                if pick_counted:
                    _sample_choices_maker_counted(
                        probability[:n_alts], random_draws[j0], result[j0], result_p[j0], pick_count[j0]
                    )
                else:
                    _sample_choices_maker(probability[:n_alts], random_draws[j0], result[j0], result_p[j0])
    return result, result_p, pick_count, _logsums

"""


def zero_size_to_None(x):
    """
    Replace an empty array by None.

    Returns None when `x` is None or ``x.size == 0``; any other `x` is
    returned unchanged.
    """
    if x is None:
        return None
    return None if x.size == 0 else x


def squeeze(x, *args):
    """
    Squeeze an array with numpy, treating empty or missing input as None.

    Parameters
    ----------
    x : array-like or None
        The value to squeeze.  It is first passed through
        `zero_size_to_None`, so a zero-size array is handled like None.
    *args
        Forwarded positionally to `numpy.squeeze` after `x`.

    Returns
    -------
    object or None
        The result of ``np.squeeze(x, *args)``, or None when `x` is None
        or has zero size.

    Raises
    ------
    Exception
        Any exception raised by `numpy.squeeze` is logged at error level
        and then re-raised unchanged.
    """
    x = zero_size_to_None(x)
    if x is not None:
        try:
            return np.squeeze(x, *args)
        except Exception:
            # Log what we were trying to do, then let the original error propagate.
            if not hasattr(x, "shape"):
                logger.error(f"failed to squeeze {args!r} from array of unknown shape")
            else:
                logger.error(f"failed to squeeze {args!r} from array of shape {x.shape}")
            raise
    return None


[docs] class Flow: """ A prepared data flow. Parameters ---------- tree : DataTree The tree from whence the output will be constructed. defs : Mapping[str,str] Gives the names and definitions for the variables to create in the generated output. error_model : {'numpy', 'python'}, default 'numpy' The error_model option controls the divide-by-zero behavior. Setting it to ‘python’ causes divide-by-zero to raise exception like CPython. Setting it to ‘numpy’ causes divide-by-zero to set the result to +/-inf or nan. cache_dir : Path-like, optional A location to write out generated python and numba code. If not provided, a unique temporary directory is created. name : str, optional The name of this Flow used for writing out cached files. If not provided, a unique name is generated. If `cache_dir` is given, be sure to avoid name conflicts with other flow's in the same directory. dtype : str, default "float32" The name of the numpy dtype that will be used for the output. boundscheck : bool, default False If True, boundscheck enables bounds checking for array indices, and out of bounds accesses will raise IndexError. The default is to not do bounds checking, which is faster but can produce garbage results or segfaults if there are problems, so try turning this on for debugging if you are getting unexplained errors or crashes. nopython : bool, default True Compile using numba's `nopython` mode. Provided for debugging only, as there's little point in turning this off for production code, as all the speed benefits of sharrow will be lost. fastmath : bool, default True If true, fastmath enables the use of "fast" floating point transforms, which can improve performance but can result in tiny distortions in results. See numba docs for details. parallel : bool, default True Enable or disable parallel computation for certain functions. readme : str, optional A string to inject as a comment at the top of the flow Python file. 
flow_library : Mapping[str,Flow], optional An in-memory cache of precompiled Flow objects. Using this can result in performance improvements when repeatedly using the same definitions. extra_hash_data : Tuple[Hashable], optional Additional data used for generating the flow hash. Useful to prevent conflicts when using a flow_library with multiple similar flows. write_hash_audit : bool, default True Writes a hash audit log into a comment in the flow Python file, for debugging purposes. hashing_level : int, default 1 Level of detail to write into flow hashes. Increase detail to avoid hash conflicts for similar flows. """ def __new__( cls, tree, defs, error_model="numpy", cache_dir=None, name=None, dtype="float32", boundscheck=False, nopython=True, fastmath=True, parallel=True, readme=None, flow_library=None, extra_hash_data=(), write_hash_audit=True, hashing_level=1, dim_order=None, dim_exclude=None, bool_wrapping=False, with_root_node_name=None, ): assert isinstance(tree, DataTree) tree.digitize_relationships(inplace=True) self = super().__new__(cls) # clean defs with hidden values defs = {k: _flip_flop_def(v) for k, v in defs.items()} # start init up to flow_hash self.__initialize_1( tree, defs, cache_dir=cache_dir, extra_hash_data=extra_hash_data, hashing_level=hashing_level, dim_order=dim_order, dim_exclude=dim_exclude, error_model=error_model, boundscheck=boundscheck, nopython=nopython, fastmath=fastmath, bool_wrapping=bool_wrapping, ) # return from library if available if flow_library is not None and self.flow_hash in flow_library: logger.info(f"flow exists in library: {self.flow_hash}") result = flow_library[self.flow_hash] result.tree = tree return result # otherwise finish normal init self.__initialize_2( defs, error_model=error_model, name=name, dtype=dtype, boundscheck=boundscheck, nopython=nopython, fastmath=fastmath, readme=readme, parallel=parallel, extra_hash_data=extra_hash_data, write_hash_audit=write_hash_audit, 
with_root_node_name=with_root_node_name, ) if flow_library is not None: flow_library[self.flow_hash] = self self.with_root_node_name = with_root_node_name return self def __initialize_1( self, tree, defs, cache_dir=None, extra_hash_data=(), error_model="numpy", boundscheck=False, nopython=True, fastmath=True, hashing_level=1, dim_order=None, dim_exclude=None, bool_wrapping=False, ): """ Initialize up to the flow_hash. See main docstring for arguments. """ if cache_dir is None: import tempfile self.temp_cache_dir = tempfile.TemporaryDirectory() self.cache_dir = self.temp_cache_dir.name else: self.cache_dir = cache_dir self.tree = tree self._raw_functions = {} self._secondary_flows = {} self.dim_order = dim_order self.dim_exclude = dim_exclude self.bool_wrapping = bool_wrapping all_raw_names = set() all_name_tokens = set() for _k, expr in defs.items(): plain_names, attribute_pairs, subscript_pairs = extract_names_2(expr) all_raw_names |= plain_names if self.tree.root_node_name: all_raw_names |= attribute_pairs.get(self.tree.root_node_name, set()) all_raw_names |= subscript_pairs.get(self.tree.root_node_name, set()) dimensions_ordered = presorted( self.tree.sizes, self.dim_order, self.dim_exclude ) index_slots = {i: n for n, i in enumerate(dimensions_ordered)} self.arg_name_positions = index_slots self.arg_names = dimensions_ordered self.output_name_positions = {} self._used_extra_vars = {} if self.tree.extra_vars: for k, v in self.tree.extra_vars.items(): if k in all_raw_names: self._used_extra_vars[k] = v self._used_extra_funcs = set() if self.tree.extra_funcs: for f in self.tree.extra_funcs: if f.__name__ in all_raw_names: self._used_extra_funcs.add(f.__name__) self._used_aux_vars = [] for aux_var in self.tree.aux_vars: if aux_var in all_raw_names: self._used_aux_vars.append(aux_var) subspace_names = set() for k, _ in self.tree.subspaces_iter(): subspace_names.add(k) for k in self.tree.subspace_fallbacks: subspace_names.add(k) optional_get_tokens = 
ExtractOptionalGetTokens(from_names=subspace_names).check( defs.values() ) self._optional_get_tokens = [] if optional_get_tokens: for _spacename, _varname in optional_get_tokens: found = False if ( _spacename in self.tree.subspaces and _varname in self.tree.subspaces[_spacename] ): self._optional_get_tokens.append(f"__{_spacename}__{_varname}:True") found = True elif _spacename in self.tree.subspace_fallbacks: for _subspacename in self.tree.subspace_fallbacks[_spacename]: if _varname in self.tree.subspaces[_subspacename]: self._optional_get_tokens.append( f"__{_subspacename}__{_varname}:__{_spacename}__{_varname}" ) found = True break if not found: self._optional_get_tokens.append( f"__{_spacename}__{_varname}:False" ) self._hashing_level = hashing_level if self._hashing_level > 1: func_code, all_name_tokens = self.init_sub_funcs( defs, error_model=error_model, boundscheck=boundscheck, nopython=nopython, fastmath=fastmath, ) self._func_code = func_code self._namespace_names = sorted(all_name_tokens) else: self._func_code = None self._namespace_names = None self.encoding_dictionaries = {} # compute the complete hash including defs, used_extra_vars, and namespace_names # digest size 20 creates a base32 encoded 32 character flow_hash string flow_hash = hashlib.blake2b(digest_size=20) flow_hash_audit = [] def _flow_hash_push(x): nonlocal flow_hash, flow_hash_audit y = str(x) flow_hash.update(y.encode("utf8")) flow_hash_audit.append(y.replace("\n", "\n# ")) _flow_hash_push("---DataTree Flow---") for k, v in defs.items(): _flow_hash_push(k) _flow_hash_push(v) for k in sorted(self._used_extra_vars): v = self._used_extra_vars[k] _flow_hash_push(k) _flow_hash_push(v) for k in sorted(self._used_aux_vars): _flow_hash_push(f"aux_var:{k}") for k in sorted(self._used_extra_funcs): _flow_hash_push(f"func:{k}") for k in sorted(self._optional_get_tokens): _flow_hash_push(f"OPTIONAL:{k}") _flow_hash_push("---DataTree---") for k in self.arg_names: _flow_hash_push(f"arg:{k}") for k in 
self.tree._hash_features(): if self._hashing_level > 0 or True: # or not k.startswith("relationship:"): _flow_hash_push(k) if self.dim_order: _flow_hash_push("---dim-order---") for k in self.dim_order: _flow_hash_push(k) for sname, sdata in self.tree.subspaces_iter(): digital_encoding_hashes = set() for iname, idata in sdata.digital_encoding.info().items(): digital_encoding_hashes.add(f"digital_encoding:{sname}:{iname}:{idata}") # ensure these are hashed in a stable ordering for ihash in sorted(digital_encoding_hashes): _flow_hash_push(ihash) if self._hashing_level > 1: for k in sorted(self._namespace_names): if k.startswith("__base__"): continue _flow_hash_push(k) parts = k.split("__") if len(parts) > 2: try: digital_encoding = self.tree.subspaces[parts[1]][ "__".join(parts[2:]) ].attrs["digital_encoding"] except (AttributeError, KeyError): pass else: if digital_encoding: for de_k in sorted(digital_encoding.keys()): de_v = digital_encoding[de_k] if de_k == "dictionary": self.encoding_dictionaries[k] = de_v _flow_hash_push((k, "digital_encoding", de_k, de_v)) for k in extra_hash_data: _flow_hash_push(k) _flow_hash_push(f"boundscheck={boundscheck}") _flow_hash_push(f"error_model={error_model}") _flow_hash_push(f"fastmath={fastmath}") _flow_hash_push(f"bool_wrapping={bool_wrapping}") self.flow_hash = base64.b32encode(flow_hash.digest()).decode() self.flow_hash_audit = "]\n# [".join(flow_hash_audit) def _index_slots(self): return { i: n for n, i in enumerate( presorted(self.tree.sizes, self.dim_order, self.dim_exclude) ) } def init_sub_funcs( self, defs, error_model="numpy", boundscheck=False, nopython=True, fastmath=True, ): func_code = "" all_name_tokens = set() index_slots = { i: n for n, i in enumerate( presorted(self.tree.sizes, self.dim_order, self.dim_exclude) ) } self.arg_name_positions = index_slots candidate_names = self.tree.namespace_names() candidate_names |= set(f"__aux_var__{i}" for i in self.tree.aux_vars.keys()) meta_data = {} if 
self.tree.relationships_are_digitized: for spacename, spacearrays in self.tree.subspaces.items(): dim_slots = {} spacekeys = list(spacearrays.keys()) + list(spacearrays.coords.keys()) for k1 in spacekeys: try: spacearrays_vars = spacearrays._variables except AttributeError: spacearrays_vars = spacearrays try: toks, blends = self.tree._arg_tokenizer( spacename, spacearray=spacearrays_vars[k1], spacearrayname=k1, exclude_dims=self.dim_exclude, ) except ValueError: pass else: dim_slots[k1] = toks try: digital_encodings = spacearrays.digital_encoding.info() except AttributeError: digital_encodings = {} blenders = spacearrays.redirection.blenders meta_data[spacename] = (dim_slots, digital_encodings, blenders) else: for spacename, spacearrays in self.tree.subspaces.items(): dim_slots = {} spacekeys = list(spacearrays.keys()) + list(spacearrays.coords.keys()) for k1 in spacekeys: try: _dims = spacearrays._variables[k1].dims except AttributeError: _dims = spacearrays[k1].dims dim_slots[k1] = [index_slots[z] for z in _dims] try: digital_encodings = spacearrays.digital_encoding.info() except AttributeError: digital_encodings = {} blenders = spacearrays.redirection.blenders meta_data[spacename] = (dim_slots, digital_encodings, blenders) # write individual function files for each expression for n, (k, expr) in enumerate(defs.items()): expr = str(expr).lstrip() prior_expr = init_expr = expr other_way = True while other_way: other_way = False # if other_way is triggered, there may be residual other terms # that were not addressed, so this loop should be applied again. 
for spacename in self.tree.subspaces.keys(): dim_slots, digital_encodings, blenders = meta_data[spacename] try: expr = expression_for_numba( expr, spacename, dim_slots, dim_slots, digital_encodings=digital_encodings, extra_vars=self.tree.extra_vars, blenders=blenders, bool_wrapping=self.bool_wrapping, original_expr=init_expr, ) except KeyError as key_err: # there was an error, but lets make sure we process the # whole expression to rewrite all the things we can before # moving on to the fallback processing. expr = expression_for_numba( expr, spacename, dim_slots, dim_slots, digital_encodings=digital_encodings, extra_vars=self.tree.extra_vars, blenders=blenders, bool_wrapping=self.bool_wrapping, swallow_errors=True, original_expr=init_expr, ) # Now for the fallback processing... if ".." in key_err.args[0]: topkey, attrkey = key_err.args[0].split("..") else: raise # check if we can resolve this name on any other subspace for other_spacename in self.tree.subspace_fallbacks.get( topkey, [] ): dim_slots, digital_encodings, blenders = meta_data[ other_spacename ] try: expr = expression_for_numba( expr, spacename, dim_slots, dim_slots, digital_encodings=digital_encodings, prefer_name=other_spacename, extra_vars=self.tree.extra_vars, blenders=blenders, bool_wrapping=self.bool_wrapping, original_expr=init_expr, ) except KeyError as err: # noqa: F841 pass else: other_way = True # at least one variable was found in a fallback break if not other_way and "get" in expr: # any remaining "get" expressions with defaults should now use them try: expr = expression_for_numba( expr, spacename, dim_slots, dim_slots, digital_encodings=digital_encodings, extra_vars=self.tree.extra_vars, blenders=blenders, bool_wrapping=self.bool_wrapping, get_default=True, original_expr=init_expr, ) except KeyError as err: # noqa: F841 pass else: other_way = True # at least one variable was found in a get break # check if we can resolve this "get" on any other subspace for other_spacename in 
self.tree.subspace_fallbacks.get( topkey, [] ): dim_slots, digital_encodings, blenders = meta_data[ other_spacename ] try: expr = expression_for_numba( expr, spacename, dim_slots, dim_slots, digital_encodings=digital_encodings, prefer_name=other_spacename, extra_vars=self.tree.extra_vars, blenders=blenders, bool_wrapping=self.bool_wrapping, get_default=True, original_expr=init_expr, ) except KeyError as err: # noqa: F841 pass else: other_way = True # at least one variable was found in a fallback break if not other_way: raise if prior_expr == expr: # nothing was changed, break out of loop break else: # something was changed, run the loop again to confirm # nothing else needs to change prior_expr = expr # now process for subspace fallbacks for gd in [False, True]: # first run all these with get_default off, nothing drops to defaults # if we might find it later. Then do a second pass with get_default on. for ( alias_spacename, actual_spacenames, ) in self.tree.subspace_fallbacks.items(): for actual_spacename in actual_spacenames: dim_slots, digital_encodings, blenders = meta_data[ actual_spacename ] try: expr = expression_for_numba( expr, alias_spacename, dim_slots, dim_slots, digital_encodings=digital_encodings, prefer_name=actual_spacename, extra_vars=self.tree.extra_vars, blenders=blenders, bool_wrapping=self.bool_wrapping, get_default=gd, original_expr=init_expr, ) except KeyError: # there was an error, but lets make sure we process the # whole expression to rewrite all the things we can before # moving on to the fallback processing. expr = expression_for_numba( expr, alias_spacename, dim_slots, dim_slots, digital_encodings=digital_encodings, prefer_name=actual_spacename, extra_vars=self.tree.extra_vars, blenders=blenders, bool_wrapping=self.bool_wrapping, swallow_errors=True, get_default=gd, original_expr=init_expr, ) # now find instances where an identifier is previously created in this flow. 
expr = expression_for_numba( expr, "", (), self.output_name_positions, "_outputs", extra_vars=self.tree.extra_vars, bool_wrapping=self.bool_wrapping, original_expr=init_expr, ) aux_tokens = { k: ast.parse(f"__aux_var__{k}", mode="eval").body for k in self.tree.aux_vars.keys() } # now handle aux vars expr = expression_for_numba( expr, "", (), spacevars=aux_tokens, prefer_name="aux_var", extra_vars=self.tree.extra_vars, bool_wrapping=self.bool_wrapping, original_expr=init_expr, ) if (k == init_expr) and (init_expr == expr) and k.isidentifier(): logger.error(f"unable to rewrite '{k}' to itself") raise ValueError(f"unable to rewrite '{k}' to itself") logger.debug(f"[{k}] rewrite {init_expr} -> {expr}") if not candidate_names: raise ValueError("there are no candidate namespace names loaded") f_name_tokens, f_arg_tokens = filter_name_tokens(expr, candidate_names) all_name_tokens |= f_name_tokens argtokens = sorted(f_arg_tokens) argtokens_ = ", ".join(argtokens) if argtokens_: argtokens_ += ", " func_code += FUNCTION_TEMPLATE.format( expr=expr, fname=clean(k), argtokens=argtokens_, nametokens=", ".join(sorted(f_name_tokens)), error_model=error_model, extra_imports="\n".join( [ "from .extra_funcs import *" if self.tree.extra_funcs else "", "from .extra_vars import *" if self._used_extra_vars else "", ] ), boundscheck=boundscheck, nopython=nopython, fastmath=fastmath, init_expr=init_expr if k == init_expr else f"{k}: {init_expr}", ) self._raw_functions[k] = (init_expr, expr, f_name_tokens, argtokens) self.output_name_positions[k] = n return func_code, all_name_tokens def __initialize_2( self, defs, error_model="numpy", name=None, dtype="float32", boundscheck=False, nopython=True, fastmath=True, readme=None, parallel=True, extra_hash_data=(), write_hash_audit=True, with_root_node_name=None, ): """ Second step in initialization, only used if the flow is not cached. 
Parameters ---------- tree : DataTree defs : Dict[str,str] Gives the names and definitions for the columns to create in our generated table. error_model : {'numpy', 'python'}, default 'numpy' The error_model option controls the divide-by-zero behavior. Setting it to ‘python’ causes divide-by-zero to raise exception like CPython. Setting it to ‘numpy’ causes divide-by-zero to set the result to +/-inf or nan. cache_dir : Path-like, optional A location to write out generated python and numba code. If not provided, a unique temporary directory is created. name : str, optional The name of this Flow used for writing out cached files. If not provided, a unique name is generated. If `cache_dir` is given, be sure to avoid name conflicts with other flow's in the same directory. """ if self._hashing_level <= 1: func_code, all_name_tokens = self.init_sub_funcs( defs, error_model=error_model, boundscheck=boundscheck, nopython=nopython, fastmath=fastmath, ) self._func_code = func_code self._namespace_names = sorted(all_name_tokens) for k in sorted(self._namespace_names): if k.startswith("__base__"): continue parts = k.split("__") if len(parts) > 2: try: digital_encoding = self.tree.subspaces[parts[1]][ "__".join(parts[2:]) ].attrs["digital_encoding"] except (AttributeError, KeyError): pass else: if digital_encoding: for de_k in sorted(digital_encoding.keys()): de_v = digital_encoding[de_k] if de_k == "dictionary": self.encoding_dictionaries[k] = de_v # assign flow name based on hash unless otherwise given if name is None: name = f"flow_{self.flow_hash}" self.name = name # create the package directory for the flow if it does not exist os.makedirs(self.cache_dir, exist_ok=True) # if an existing __init__ file matches the hash, just use it init_file = os.path.join(self.cache_dir, self.name, "__init__.py") if os.path.isfile(init_file): with open(init_file) as f: content = f.read() s = re.search("""flow_hash = ['"](.*)['"]""", content) else: s = None if s and s.group(1) == 
self.flow_hash: logger.info(f"using existing flow code {self.flow_hash}") writing = False else: logger.info(f"writing fresh flow code {self.flow_hash}") writing = True if writing: dependencies = { "import numpy as np", "import numba as nb", "import pandas as pd", "import pyarrow as pa", "import xarray as xr", "import sharrow as sh", "import inspect", "import warnings", "from contextlib import suppress", "from numpy import log, exp, log1p, expm1", "from sharrow.maths import piece, hard_sigmoid, transpose_leading, clip, digital_decode", "from sharrow.sparse import get_blended_2, isnan_fast_safe", } func_code = self._func_code # write extra_funcs file, if there are any extra_funcs if self.tree.extra_funcs: try: import cloudpickle as pickle except ModuleNotFoundError: import pickle func_code += "\n\n# extra_funcs\n" for x_func in self.tree.extra_funcs: if x_func.__name__ in self._used_extra_funcs: if x_func.__module__ == "__main__": dependencies.add("import pickle") func_code += f"\n\n{x_func.__name__} = pickle.loads({repr(pickle.dumps(x_func))})\n" else: func_code += f"\n\nfrom {x_func.__module__} import {x_func.__name__}\n" # write extra_vars file, if there are any used extra_vars if self._used_extra_vars: try: import cloudpickle as pickle except ModuleNotFoundError: import pickle buffer = io.StringIO() # any_pickle = False for x_name, x_var in self._used_extra_vars.items(): if isinstance(x_var, (float, int, str)): buffer.write(f"{x_name} = {x_var!r}\n") else: buffer.write( f"{x_name} = pickle.loads({repr(pickle.dumps(x_var))})\n" ) dependencies.add("import pickle") with io.StringIO() as x_code: x_code.write("\n") x_code.write(buffer.getvalue()) func_code += "\n\n# extra_vars\n" func_code += x_code.getvalue() # write encoding dictionaries, if there are any used if len(self.encoding_dictionaries): dependencies.add("import pickle") try: import cloudpickle as pickle except ModuleNotFoundError: import pickle buffer = io.StringIO() for x_name, x_dict in 
self.encoding_dictionaries.items(): buffer.write( f"__encoding_dict{x_name} = pickle.loads({repr(pickle.dumps(x_dict))})\n" ) with io.StringIO() as x_code: x_code.write("\n") x_code.write(buffer.getvalue()) func_code += "\n\n# encoding dictionaries\n" func_code += x_code.getvalue() # write the master module for this flow os.makedirs(os.path.join(self.cache_dir, self.name), exist_ok=True) with rewrite( os.path.join(self.cache_dir, self.name, "__init__.py"), "wt" ) as f_code: f_code.write( textwrap.dedent( f""" # this module generated automatically using sharrow version {__version__} # generation time: {time.strftime('%d %B %Y %I:%M:%S %p')} """ )[1:] ) if readme: f_code.write( textwrap.indent( textwrap.dedent(readme), "# ", lambda line: True, ) ) f_code.write("\n\n") dependencies_ = set() for depend in sorted(dependencies): if depend.startswith("import ") and "." not in depend: f_code.write(f"{depend}\n") dependencies_.add(depend) dependencies -= dependencies_ for depend in sorted(dependencies): if depend.startswith("import "): f_code.write(f"{depend}\n") dependencies_.add(depend) dependencies -= dependencies_ for depend in sorted(dependencies): if depend.startswith("from ") and "from ." 
not in depend: f_code.write(f"{depend}\n") dependencies_.add(depend) dependencies -= dependencies_ for depend in sorted(dependencies): f_code.write(f"{depend}\n") f_code.write("\n\n# namespace names\n") for k in sorted(self._namespace_names): f_code.write(f"# - {k}\n") if extra_hash_data: f_code.write("\n\n# extra_hash_data\n") for k in extra_hash_data: f_code.write(f"# - {str(k)}\n") f_code.write("\n\n# function code\n") f_code.write(f"\n\n{blacken(func_code)}") f_code.write("\n\n# machinery code\n\n") if self.tree.relationships_are_digitized: if with_root_node_name is None: with_root_node_name = self.tree.root_node_name if with_root_node_name is None: with_root_node_name = self.tree.root_node_name root_dims = list( presorted( self.tree._graph.nodes[with_root_node_name][ "dataset" ].sizes, self.dim_order, self.dim_exclude, ) ) n_root_dims = len(root_dims) if n_root_dims == 1: js = "j0" elif n_root_dims == 2: js = "j0, j1" else: raise NotImplementedError( f"n_root_dims only supported up to 2, not {n_root_dims}" ) meta_code = [] meta_code_dot = [] for n, k in enumerate(self._raw_functions): f_name_tokens = self._raw_functions[k][2] f_arg_tokens = self._raw_functions[k][3] f_name_tokens = ", ".join(sorted(f_name_tokens)) f_args_j = ", ".join([f"j{argn[-1]}" for argn in f_arg_tokens]) if f_args_j: f_args_j += ", " meta_code.append( f"result[{js}, {n}] = ({clean(k)}({f_args_j}result[{js}], {f_name_tokens})).item()" ) meta_code_dot.append( f"intermediate[{n}] = ({clean(k)}({f_args_j}intermediate, {f_name_tokens})).item()" ) meta_code_stack = textwrap.indent( "\n".join(meta_code), " " * 12 ).lstrip() meta_code_stack_dot = textwrap.indent( "\n".join(meta_code_dot), " " * 12 ).lstrip() len_self_raw_functions = len(self._raw_functions) joined_namespace_names = "\n ".join( f"{nn}," for nn in self._namespace_names ) linefeed = "\n " if not meta_code_stack_dot: meta_code_stack_dot = "pass" if n_root_dims == 1: meta_template = IRUNNER_1D_TEMPLATE.format(**locals()).format( 
**locals() ) meta_template_dot = IDOTTER_1D_TEMPLATE.format( **locals() ).format(**locals()) line_template = ILINER_1D_TEMPLATE.format(**locals()).format( **locals() ) mnl_template = MNL_1D_TEMPLATE.format(**locals()).format( **locals() ) nl_template = NL_1D_TEMPLATE.format(**locals()).format( **locals() ) elif n_root_dims == 2: meta_template = IRUNNER_2D_TEMPLATE.format(**locals()).format( **locals() ) meta_template_dot = IDOTTER_2D_TEMPLATE.format( **locals() ).format(**locals()) line_template = ILINER_2D_TEMPLATE.format(**locals()).format( **locals() ) mnl_template = MNL_2D_TEMPLATE.format(**locals()).format( **locals() ) nl_template = "" else: raise ValueError(f"invalid n_root_dims {n_root_dims}") else: raise RuntimeError("digitization is now required") f_code.write(blacken(textwrap.dedent(line_template))) f_code.write("\n\n") f_code.write(blacken(textwrap.dedent(mnl_template))) f_code.write("\n\n") f_code.write(blacken(textwrap.dedent(nl_template))) f_code.write("\n\n") f_code.write(blacken(textwrap.dedent(meta_template))) f_code.write("\n\n") f_code.write(blacken(textwrap.dedent(meta_template_dot))) f_code.write("\n\n") f_code.write(blacken(self._spill(self._namespace_names))) if write_hash_audit: f_code.write("\n\n# hash audit\n# [") f_code.write(self.flow_hash_audit) f_code.write("]\n") f_code.write("\n\n") f_code.write( "# Greetings, tinkerer! The `flow_hash` included here is a safety \n" "# measure to prevent unknowing users creating a mess by modifying \n" "# the code in this module so that it no longer matches the expected \n" "# variable definitions. If you want to modify this code, you should \n" "# delete this hash to allow the code to run without any checks, but \n" "# you do so at your own risk. 
\n" ) f_code.write(f"flow_hash = {self.flow_hash!r}\n") abs_cache_dir = os.path.abspath(self.cache_dir) if str(abs_cache_dir) not in sys.path: logger.debug(f"inserting {abs_cache_dir} into sys.path") sys.path.insert(0, str(abs_cache_dir)) added_cache_dir_to_sys_path = True else: added_cache_dir_to_sys_path = False importlib.invalidate_caches() logger.debug(f"importing {self.name}") try: module = importlib.import_module(self.name) except ModuleNotFoundError: # maybe we got out in front of the file system, wait a beat and retry time.sleep(2) try: module = importlib.import_module(self.name) except ModuleNotFoundError: logger.error(f"- os.getcwd: {os.getcwd()}") for i in sys.path: logger.error(f"- sys.path: {i}") raise if added_cache_dir_to_sys_path: sys.path = sys.path[1:] self._runner = getattr(module, "runner", None) self._dotter = getattr(module, "dotter", None) self._irunner = getattr(module, "irunner", None) self._logit_ndims = getattr(module, "logit_ndims", None) self._imnl = getattr(module, "mnl_transform", None) self._imnl_plus1d = getattr(module, "mnl_transform_plus1d", None) self._inestedlogit = getattr(module, "nl_transform", None) self._idotter = getattr(module, "idotter", None) self._linemaker = getattr(module, "linemaker", None) if not writing: self.function_names = module.function_names self.output_name_positions = module.output_name_positions def load_raw(self, rg, args, runner=None, dtype=None, dot=None): assert isinstance(rg, DataTree) with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=nb.NumbaExperimentalFeatureWarning ) assembled_args = [args.get(k) for k in self.arg_name_positions.keys()] for aa in assembled_args: if aa.dtype.kind != "i": warnings.warn( "position arguments are not all integers", stacklevel=2 ) try: if runner is None: if dot is None: runner_ = self._runner else: runner_ = self._dotter else: runner_ = runner named_args = inspect.getfullargspec(runner_.py_func).args arguments = [] for arg in named_args: if 
arg in {"dtype", "dotarray", "inputarray", "argarray"}: continue if arg.startswith("_arg"): continue arg_value = rg.get_named_array(arg) # aux_vars get passed through as is, not forced to be arrays if arg.startswith("__aux_var"): arguments.append(arg_value) else: arg_value_array = np.asarray(arg_value) if arg_value_array.dtype.kind == "O": # convert object arrays to unicode str # and replace missing values with NAK='\u0015' # that can be found by `isnan_fast_safe` # This is done for compatability and likely ruins performance arg_value_array_ = arg_value_array.astype("unicode") arg_value_array_[pd.isnull(arg_value_array)] = "\u0015" arg_value_array = arg_value_array_ arguments.append(arg_value_array) kwargs = {} if dtype is not None: kwargs["dtype"] = dtype # else: # kwargs['dtype'] = np.float64 if dot is not None: kwargs["dotarray"] = dot # logger.debug(f"load_raw calling runner with {assembled_args.shape=}, {assembled_inputs.shape=}") return runner_(*assembled_args, *arguments, **kwargs) except nb.TypingError as err: _raw_functions = getattr(self, "_raw_functions", {}) logger.error(f"nb.TypingError in {len(_raw_functions)} functions") for k, v in _raw_functions.items(): logger.error(f"{k} = {v[0]} = {v[1]}") if "NameError:" in err.args[0]: import re problem = re.search("NameError: (.*)\x1b", err.args[0]) if problem: raise NameError(problem.group(1)) from err problem = re.search("NameError: (.*)\n", err.args[0]) if problem: raise NameError(problem.group(1)) from err raise except KeyError as err: # raise the inner key error which is more helpful context = getattr(err, "__context__", None) if context: raise context from None else: raise err def _iload_raw( self, rg, runner=None, dtype=None, dot=None, mnl=None, pick_counted=False, logsums=False, nesting=None, mask=None, compile_watch=False, ): assert isinstance(rg, DataTree) with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=nb.NumbaExperimentalFeatureWarning ) try: known_arg_names = { 
"dtype", "dotarray", "argshape", "random_draws", "pick_counted", "logsums", "choice_dtype", "pick_count_dtype", "mask", } if runner is None: if mnl is not None: if nesting is None: if dot.shape[1] > 1: runner_ = self._imnl_plus1d else: runner_ = self._imnl else: runner_ = self._inestedlogit known_arg_names.update( { "n_nodes", "n_alts", "edges_up", "edges_dn", "mu_params", "start_slots", "len_slots", } ) elif dot is None: runner_ = self._irunner known_arg_names.update({"mask"}) if ( mask is not None and dtype is not None and not np.issubdtype(dtype, np.floating) ): raise TypeError("cannot use mask unless dtype is float") else: runner_ = self._idotter else: runner_ = runner try: fullargspec = inspect.getfullargspec(runner_.py_func) except AttributeError: fullargspec = inspect.getfullargspec(runner_) named_args = fullargspec.args arguments = [] _arguments_names = [] for arg in named_args: if arg in known_arg_names: continue argument = rg.get_named_array(arg) # aux_vars get passed through as is, not forced to be arrays if arg.startswith("__aux_var"): arguments.append(argument) else: if argument.dtype.kind == "O": # convert object arrays to unicode str # and replace missing values with NAK='\u0015' # that can be found by `isnan_fast_safe` # This is done for compatability and likely ruins performance argument_ = argument.astype("unicode") argument_[pd.isnull(argument)] = "\u0015" arguments.append(np.asarray(argument_)) else: arguments.append(np.asarray(argument)) _arguments_names.append(arg) kwargs = {} if dtype is not None: kwargs["dtype"] = dtype if dot is not None: kwargs["dotarray"] = np.asarray(dot) if mnl is not None: kwargs["random_draws"] = mnl kwargs["pick_counted"] = pick_counted kwargs["logsums"] = logsums if nesting is not None: nesting.pop("edges_1st", None) # unused in simple NL nesting.pop("edges_alloc", None) # unused in simple NL kwargs.update(nesting) if mask is not None: kwargs["mask"] = mask if self.with_root_node_name is None: tree_root_dims = 
rg.root_dataset.sizes else: tree_root_dims = rg._graph.nodes[self.with_root_node_name][ "dataset" ].sizes argshape = [ tree_root_dims[i] for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude) ] if mnl is not None: if nesting is not None: n_alts = nesting["n_alts"] elif len(argshape) == 2: n_alts = argshape[1] else: n_alts = kwargs["dotarray"].shape[1] if n_alts < 128: kwargs["choice_dtype"] = np.int8 elif n_alts < 32768: kwargs["choice_dtype"] = np.int16 if logger.isEnabledFor(logging.DEBUG): logger.debug( "========= PASSING ARGUMENT TO SHARROW LOAD ==========" ) logger.debug(f"{argshape=}") for _name, _info in zip(_arguments_names, arguments): try: logger.debug(f"ARG {_name}: {_info.dtype}, {_info.shape}") except AttributeError: alt_repr = repr(_info) if len(alt_repr) < 200: logger.debug(f"ARG {_name}: {alt_repr}") else: logger.debug(f"ARG {_name}: type={type(_info)}") for _name, _info in kwargs.items(): try: logger.debug(f"KWARG {_name}: {_info.dtype}, {_info.shape}") except AttributeError: alt_repr = repr(_info) if len(alt_repr) < 200: logger.debug(f"KWARG {_name}: {alt_repr}") else: logger.debug(f"KWARG {_name}: type={type(_info)}") logger.debug( "========= ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ==========" ) result = runner_(np.asarray(argshape), *arguments, **kwargs) if compile_watch: self.check_cache_misses( runner_, log_details=compile_watch != "simple" ) return result except nb.TypingError as err: _raw_functions = getattr(self, "_raw_functions", {}) logger.error(f"nb.TypingError in {len(_raw_functions)} functions") for k, v in _raw_functions.items(): logger.error(f"{k} = {v[0]} = {v[1]}") if "NameError:" in err.args[0]: import re problem = re.search("NameError: (.*)\x1b", err.args[0]) if problem: raise NameError(problem.group(1)) from err problem = re.search("NameError: (.*)\n", err.args[0]) if problem: raise NameError(problem.group(1)) from err raise # except KeyError as err: # # raise the inner key error which is more helpful # context = 
getattr(err, "__context__", None) # if context: # raise context # else: # raise err def check_cache_misses(self, *funcs, fresh=True, log_details=True): self.compiled_recently = False if not hasattr(self, "_known_cache_misses"): self._known_cache_misses = {} if not funcs: funcs = ( self._imnl, self._imnl_plus1d, self._inestedlogit, self._irunner, self._idotter, ) for f in funcs: if f is None: continue try: fullargspec = inspect.getfullargspec(f.py_func) except AttributeError: fullargspec = inspect.getfullargspec(f) named_args = fullargspec.args cache_misses = f.stats.cache_misses runner_name = f.__name__ if cache_misses: if runner_name not in self._known_cache_misses: self._known_cache_misses[runner_name] = {} if fresh: known_cache_misses = self._known_cache_misses[runner_name] else: known_cache_misses = {} for k, v in cache_misses.items(): if v > known_cache_misses.get(k, 0): if log_details: warning_text = "\n".join( f" - {argname}: {sig}" for (sig, argname) in zip(k, named_args) ) warning_text = f"\n{runner_name}(\n{warning_text}\n)" else: warning_text = "" timers = ( f.overloads[k] .metadata["timers"] .get("compiler_lock", "N/A") ) if isinstance(timers, float): if timers < 1e-3: timers = f"{timers/1e-6:.0f} µs" elif timers < 1: timers = f"{timers/1e-3:.1f} ms" else: timers = f"{timers:.2f} s" logger.warning( f"cache miss in {self.flow_hash}{warning_text}\n" f"Compile Time: {timers}" ) warnings.warn( f"{self.flow_hash}", CacheMissWarning, stacklevel=1 ) self.compiled_recently = True self._known_cache_misses[runner_name][k] = v return self.compiled_recently @property def cache_misses(self): """dict[str, dict]: Numba cache misses across all defined flow methods.""" misses = {} for k, v in self.__dict__.items(): from numba.core.dispatcher import Dispatcher if isinstance(v, Dispatcher): misses[k] = v.stats.cache_misses.copy() return misses def _load( self, source=None, as_dataframe=False, as_dataarray=False, as_table=False, runner=None, dtype=None, dot=None, 
logit_draws=None, pick_counted=False, compile_watch=False, logsums=0, nesting=None, mask=None, ): """ Compute the flow outputs. Parameters ---------- source : DataTree, optional This is the source of the data for this flow. If not provided, the tree used to initialize this flow is used. as_dataframe : bool, default False Return the loaded data as a pandas.DataFrame. Must not be used in conjunction with the `dot` argument. as_dataarray : bool, default False Return the loaded data as a xarray.DataArray. as_table : bool, default False Return the loaded data as a sharrow.Table (a subclass of pyarrow.Table). runner : Callable, optional Overload the prepared function with a different callable. Recommended for advanced usage only. dtype : str or dtype Override the default dtype for the result. May trigger re-compilation of the underlying code. dot : array-like, optional An array of coefficients. If provided, the function returns the dot-product of the computed expressions and this array of coefficients, but without ever materializing the array of computed expression values in memory, achiving significant performance gains. logit_draws : array-like, optional An array of random values in the unit interval. If provided, `dot` must also be provided. The dot-product is treated as the utility function for a multinomial logit model, and these draws are used to simulate choices from the implied probabilities. compile_watch : bool, default False Watch for compiled code. logsums : int, default 0 Set to 1 to return only logsums instead of making draws from logit models. Set to 2 to return both logsums and draws. 
nesting : dict, optional Nesting arrays mask : array-like, optional """ if compile_watch: compile_watch = time.time() if (as_dataframe or as_table) and dot is not None: raise ValueError("cannot format output other than as array if using dot") if source is None: source = self.tree if dtype is None and dot is not None: dtype = dot.dtype if logit_draws is None and logsums == 1: logit_draws = np.zeros(source.shape + (0,), dtype=dtype) if self.with_root_node_name is None: use_dims = list( presorted(source.root_dataset.sizes, self.dim_order, self.dim_exclude) ) else: use_dims = list( presorted( source._graph.nodes[self.with_root_node_name]["dataset"].sizes, self.dim_order, self.dim_exclude, ) ) if logit_draws is not None: if dot is None: raise NotImplementedError if dot.ndim == 1 or (dot.ndim == 2 and dot.shape[1] == 1): while logit_draws.ndim < self._logit_ndims: logit_draws = np.expand_dims(logit_draws, -1) else: while logit_draws.ndim < self._logit_ndims + 1: logit_draws = np.expand_dims(logit_draws, -1) result_dims = None result_squeeze = None if dot is None: # returning extracted raw data, with all dims plus expressions result_dims = use_dims + ["expressions"] result_squeeze = None else: if not isinstance(dot, xr.DataArray): dot_trailing_dim = ["ALT_COL"] else: dot_trailing_dim = [dot.dims[1]] if dot.ndim == 1 and logit_draws is None: # returning a dot-product for idca-type data result_dims = use_dims result_squeeze = (-1,) elif dot.ndim == 2 and logit_draws is None: # returning a dot-product for idco-type data result_dims = use_dims + dot_trailing_dim result_squeeze = None elif dot.ndim > 2 and logit_draws is None: raise NotImplementedError else: # returning a logit model result if not isinstance(logit_draws, xr.DataArray): logit_draws_trailing_dim = ["DRAW"] else: logit_draws_trailing_dim = [logit_draws.dims[-1]] if dot.ndim == 1 and logit_draws.ndim == len(use_dims): result_dims = use_dims[:-1] + logit_draws_trailing_dim elif ( dot.ndim == 2 and dot.shape[1] == 1 
and logit_draws.ndim == len(use_dims) and logit_draws.shape[-1] == 1 ): result_dims = use_dims[:-1] result_squeeze = (-1,) elif ( dot.ndim == 2 and dot.shape[1] == 1 and logit_draws.ndim == len(use_dims) ): result_dims = use_dims[:-1] + logit_draws_trailing_dim elif dot.ndim == 2 and logit_draws.ndim == len(use_dims): result_dims = use_dims[:-1] + dot_trailing_dim elif dot.ndim == 1 and logit_draws.ndim == len(use_dims) + 1: result_dims = use_dims[:-1] + logit_draws_trailing_dim if logit_draws.shape[-1] == 1: result_squeeze = (-1,) elif ( dot.ndim == 2 and logit_draws.ndim == len(use_dims) + 1 and logit_draws.shape[-1] == 1 and self._logit_ndims == 1 ): result_dims = use_dims result_squeeze = (-1,) elif ( dot.ndim == 2 and logit_draws.ndim == len(use_dims) + 1 and logit_draws.shape[-1] > 1 and self._logit_ndims == 1 ): result_dims = use_dims + logit_draws_trailing_dim elif ( dot.ndim == 2 and logit_draws.ndim == len(use_dims) + 1 and logit_draws.shape[-1] == 0 ): # logsums only result_dims = use_dims result_squeeze = (-1,) elif ( dot.ndim == 2 and logit_draws.ndim == len(use_dims) + 1 and logit_draws.shape[-1] > 1 and self._logit_ndims == 2 ): # wide choices result_dims = use_dims + logit_draws_trailing_dim elif ( dot.ndim == 2 and logit_draws.ndim == len(use_dims) + 1 and logit_draws.shape[-1] == 1 and self._logit_ndims == 2 ): # wide choices result_dims = use_dims result_squeeze = (-1,) else: print(f"{dot.ndim=}") print(f"{logit_draws.ndim=}") print(f"{len(use_dims)=}") print(f"{self._logit_ndims=}") raise NotImplementedError() # dot_collapse = False result_p = None pick_count = None out_logsum = None if dot is not None and dot.ndim == 1: dot = np.expand_dims(dot, -1) # dot_collapse = True # mnl_collapse = False # idca_collapse = False # if logit_draws is not None and logit_draws.ndim == 1: # logit_draws = np.expand_dims(logit_draws, -1) # mnl_collapse = True # elif ( # logit_draws is not None # and logit_draws.ndim == 2 # and dot.ndim == 2 # and dot.shape[1] == 
1 # ): # idca_collapse = True if not source.relationships_are_digitized: source = source.digitize_relationships() if source.relationships_are_digitized: if logit_draws is None: result = self._iload_raw( source, runner=runner, dtype=dtype, dot=dot, mask=mask, compile_watch=compile_watch, ) else: result, result_p, pick_count, out_logsum = self._iload_raw( source, runner=runner, dtype=dtype, dot=dot, mnl=logit_draws, pick_counted=pick_counted, logsums=logsums, nesting=nesting, mask=mask, compile_watch=compile_watch, ) pick_count = zero_size_to_None(pick_count) out_logsum = zero_size_to_None(out_logsum) else: raise RuntimeError("please digitize") if as_dataframe: index = getattr(source.root_dataset, "index", None) result = pd.DataFrame( result, index=index, columns=list(self._raw_functions.keys()) ) elif as_table: result = Table( {k: result[:, n] for n, k in enumerate(self._raw_functions.keys())} ) elif as_dataarray: if result_squeeze: result = squeeze(result, result_squeeze) result_p = squeeze(result_p, result_squeeze) pick_count = squeeze(pick_count, result_squeeze) if self.with_root_node_name is None: result_coords = { k: v for k, v in source.root_dataset.coords.items() if k in result_dims } else: result_coords = { k: v for k, v in source._graph.nodes[self.with_root_node_name][ "dataset" ].coords.items() if k in result_dims } if result is not None: result = xr.DataArray( result, dims=result_dims, coords=result_coords, ) if "expressions" in result_dims: result.coords["expressions"] = self.function_names if result_p is not None: result_p = xr.DataArray( result_p, dims=result_dims, coords=result_coords, ) if pick_count is not None: pick_count = xr.DataArray( pick_count, dims=result_dims, coords=result_coords, ) if out_logsum is not None: out_logsum = xr.DataArray( out_logsum, dims=result_dims[: out_logsum.ndim], coords={ k: v for k, v in source.root_dataset.coords.items() if k in result_dims[: out_logsum.ndim] }, ) else: if result_squeeze: result = squeeze(result, 
result_squeeze) result_p = squeeze(result_p, result_squeeze) pick_count = squeeze(pick_count, result_squeeze) # if compile_watch: # self.compiled_recently = False # for i in os.walk(os.path.join(self.cache_dir, self.name)): # for f in i[2]: # fi = os.path.join(i[0], f) # try: # t = os.path.getmtime(fi) # except FileNotFoundError: # # something is actively happening in this directory # self.compiled_recently = True # logger.warning( # f"unidentified activity (file deletion) detected for {self.name}" # ) # break # if t > compile_watch: # self.compiled_recently = True # logger.warning(f"compilation activity detected for {self.name}") # break # if self.compiled_recently: # break if not compile_watch: try: del self.compiled_recently except AttributeError: pass if out_logsum is not None: return result, result_p, pick_count, out_logsum if pick_count is not None: return result, result_p, pick_count if result_p is not None: return result, result_p return result
def load(self, source=None, dtype=None, compile_watch=False, mask=None):
    """
    Evaluate the flow and return the results as a plain array.

    This is a thin wrapper that forwards its arguments to `_load`
    without requesting any dataframe, table, or data array wrapping.

    Parameters
    ----------
    source : DataTree, optional
        The data this flow is evaluated against.  When omitted, the
        tree used to initialize this flow is used.
    dtype : str or dtype, optional
        Override the default dtype for the result.  May trigger
        re-compilation of the underlying code.
    compile_watch : bool, default False
        When truthy, the compiled runner is checked for new numba cache
        misses after it runs, and the `compiled_recently` flag on this
        flow records whether any were found.  When falsy, any existing
        `compiled_recently` attribute is removed.
    mask : array-like, optional
        Only compute values for items where mask is truthy.

    Returns
    -------
    numpy.array
    """
    forwarded = {
        "source": source,
        "dtype": dtype,
        "compile_watch": compile_watch,
        "mask": mask,
    }
    return self._load(**forwarded)
def load_dataframe(self, source=None, dtype=None, compile_watch=False, mask=None):
    """
    Evaluate the flow and return the results as a pandas.DataFrame.

    This forwards to `_load` with ``as_dataframe=True``.

    Parameters
    ----------
    source : DataTree, optional
        The data this flow is evaluated against.  When omitted, the
        tree used to initialize this flow is used.
    dtype : str or dtype, optional
        Override the default dtype for the result.  May trigger
        re-compilation of the underlying code.
    compile_watch : bool, default False
        When truthy, the compiled runner is checked for new numba cache
        misses after it runs, and the `compiled_recently` flag on this
        flow records whether any were found.  When falsy, any existing
        `compiled_recently` attribute is removed.
    mask : array-like, optional
        Only compute values for items where mask is truthy.

    Returns
    -------
    pandas.DataFrame
    """
    forwarded = {
        "source": source,
        "dtype": dtype,
        "as_dataframe": True,
        "compile_watch": compile_watch,
        "mask": mask,
    }
    return self._load(**forwarded)
def load_dataarray(self, source=None, dtype=None, compile_watch=False, mask=None):
    """
    Evaluate the flow and return the results as a xarray.DataArray.

    This forwards to `_load` with ``as_dataarray=True``.

    Parameters
    ----------
    source : DataTree, optional
        The data this flow is evaluated against.  When omitted, the
        tree used to initialize this flow is used.
    dtype : str or dtype, optional
        Override the default dtype for the result.  May trigger
        re-compilation of the underlying code.
    compile_watch : bool, default False
        When truthy, the compiled runner is checked for new numba cache
        misses after it runs, and the `compiled_recently` flag on this
        flow records whether any were found.  When falsy, any existing
        `compiled_recently` attribute is removed.
    mask : array-like, optional
        Only compute values for items where mask is truthy.

    Returns
    -------
    xarray.DataArray
    """
    forwarded = {
        "source": source,
        "dtype": dtype,
        "as_dataarray": True,
        "compile_watch": compile_watch,
        "mask": mask,
    }
    return self._load(**forwarded)
def dot(self, coefficients, source=None, dtype=None, compile_watch=False):
    """
    Return the dot-product of the expression results and coefficients.

    The coefficients are handed to `_load` as its `dot` argument, so the
    product is formed without ever materializing the full array of
    computed expression values in memory.

    Parameters
    ----------
    coefficients : array-like
        The array of coefficients to multiply against the computed
        expressions.
    source : DataTree, optional
        The data this flow is evaluated against.  When omitted, the
        tree used to initialize this flow is used.
    dtype : str or dtype, optional
        Override the default dtype for the result.  May trigger
        re-compilation of the underlying code.  When omitted, `_load`
        falls back to the dtype of the coefficients.
    compile_watch : bool, default False
        When truthy, the compiled runner is checked for new numba cache
        misses after it runs, and the `compiled_recently` flag on this
        flow records whether any were found.  When falsy, any existing
        `compiled_recently` attribute is removed.

    Returns
    -------
    numpy.ndarray
    """
    options = {
        "dot": coefficients,
        "dtype": dtype,
        "compile_watch": compile_watch,
    }
    return self._load(source, **options)
def dot_dataarray(self, coefficients, source=None, dtype=None, compile_watch=False):
    """
    Return the dot-product of expression results and coefficients.

    Works like `dot`, but additionally passes ``as_dataarray=True`` to
    `_load` so the result is wrapped as a xarray.DataArray.

    Parameters
    ----------
    coefficients : DataArray
        The array of coefficients to multiply against the computed
        expressions.  The product is formed without ever materializing
        the full array of computed expression values in memory.
    source : DataTree, optional
        The data this flow is evaluated against.  When omitted, the
        tree used to initialize this flow is used.
    dtype : str or dtype, optional
        Override the default dtype for the result.  May trigger
        re-compilation of the underlying code.  When omitted, `_load`
        falls back to the dtype of the coefficients.
    compile_watch : bool, default False
        When truthy, the compiled runner is checked for new numba cache
        misses after it runs, and the `compiled_recently` flag on this
        flow records whether any were found.  When falsy, any existing
        `compiled_recently` attribute is removed.

    Returns
    -------
    xarray.DataArray
    """
    options = {
        "dot": coefficients,
        "dtype": dtype,
        "as_dataarray": True,
        "compile_watch": compile_watch,
    }
    return self._load(source, **options)
def logit_draws(
    self,
    coefficients,
    draws=None,
    source=None,
    pick_counted=False,
    logsums=0,
    dtype=None,
    compile_watch=False,
    nesting=None,
    as_dataarray=False,
    mask=None,
):
    """
    Make random simulated choices for a multinomial logit model.

    Parameters
    ----------
    coefficients : array-like
        These coefficients are used as in `dot` to compute the
        dot-product of the computed expressions, and this result is
        treated as the utility function for a multinomial logit model.
    draws : array-like
        A one or two dimensional array of random values in the unit
        interval.  If one dimensional, then it must have length equal
        to the first dimension of the base `shape` of `source`, and a
        single draw will be applied for each row in that dimension.  If
        two dimensional, the first dimension must match as above, and
        the second dimension determines the number of draws applied for
        each row in the first dimension.
    source : DataTree, optional
        The data this flow is evaluated against.  When omitted, the
        tree used to initialize this flow is used.
    pick_counted : bool, default False
        Whether to tally multiple repeated choices with a pick count.
    logsums : int, default 0
        Set to 1 to return only logsums instead of making draws from
        logit models.  Set to 2 to return both logsums and draws.  The
        value is converted to ``numpy.int8`` before it is forwarded.
    dtype : str or dtype, optional
        Override the default dtype for the probability.  May trigger
        re-compilation of the underlying code.  The choices and pick
        counts (if included) are always integers.
    compile_watch : bool, default False
        When truthy, the compiled runner is checked for new numba cache
        misses after it runs, and the `compiled_recently` flag on this
        flow records whether any were found.  When falsy, any existing
        `compiled_recently` attribute is removed.
    nesting : dict, optional
        Nesting instructions.
    as_dataarray : bool, default False
        Return the outputs wrapped as xarray.DataArray objects.
    mask : array-like, optional
        Only compute values for items where mask is truthy.

    Returns
    -------
    choices : array[int]
        The positions of the simulated choices.
    probs : array[dtype]
        The probability that was associated with each simulated choice.
    pick_count : array[int], optional
        A count of how many times this choice was chosen, only included
        when the pick counts (or the logsums) are non-empty.
    out_logsum : array, optional
        The logsums, only included when they are non-empty.
    """
    settings = {
        "source": source,
        "dot": coefficients,
        "logit_draws": draws,
        "dtype": dtype,
        "pick_counted": pick_counted,
        "compile_watch": compile_watch,
        # the flag is always forwarded as an 8-bit integer
        "logsums": np.int8(logsums),
        "nesting": nesting,
        "as_dataarray": as_dataarray,
        "mask": mask,
    }
    return self._load(**settings)
@property
def defs(self):
    """dict: The first element of each raw function record, keyed by name."""
    definitions = {}
    for name, record in self._raw_functions.items():
        definitions[name] = record[0]
    return definitions


@property
def function_names(self):
    """list: The keys of the raw function registry, in registry order."""
    return list(self._raw_functions)


@function_names.setter
def function_names(self, x):
    # Names that are already registered keep their existing record; every
    # other name gets its own fresh, empty placeholder record.
    for name in x:
        if name in self._raw_functions:
            continue
        self._raw_functions[name] = (None, None, set(), [])


def _spill(self, all_name_tokens=None):
    """
    Render the flow's name bookkeeping as Python source text.

    Parameters
    ----------
    all_name_tokens : optional
        Accepted but not used by this implementation.

    Returns
    -------
    str
        Assignment statements for `output_name_positions` and
        `function_names`, preceded by blank lines.
    """
    lines = [
        "\n",
        f"output_name_positions = {self.output_name_positions!r}",
        f"function_names = {self.function_names!r}",
    ]
    return "\n".join(lines)
def show_code(self, linenos="inline"):
    """
    Render the Python code generated for this flow as highlighted HTML.

    The module file ``__init__.py`` found under this flow's name inside
    its cache directory is read and syntax-highlighted with pygments.
    This is intended primarily for inspecting the generated source in a
    Jupyter notebook, for debugging and educational purposes.

    Parameters
    ----------
    linenos : {'inline', 'table'}
        Passed through to the pygments HtmlFormatter.  With ``'table'``,
        line numbers are emitted as a table with two cells, one holding
        the line numbers and the other the whole code; this is
        copy-and-paste-friendly, but may cause alignment problems with
        some browsers or fonts.  With ``'inline'``, the line numbers are
        integrated in the ``<pre>`` tag that contains the code.

    Returns
    -------
    IPython.display.HTML
    """
    from IPython.display import HTML
    from pygments import highlight
    from pygments.formatters.html import HtmlFormatter
    from pygments.lexers.python import PythonLexer

    codefile = os.path.join(self.cache_dir, self.name, "__init__.py")
    with open(codefile) as stream:
        source_text = stream.read()

    highlighted = highlight(
        source_text, PythonLexer(), HtmlFormatter(linenos=linenos)
    )
    # a default formatter supplies the style sheet for the highlighted markup
    stylesheet = HtmlFormatter().get_style_defs(".highlight")
    # header link that opens the generated file through the x-bbedit URL scheme
    link = f'<a href="x-bbedit://open?url=file://{codefile}">{codefile}</a>'
    return HTML(f"<style>{stylesheet}</style><p>{link}</p>{highlighted}")
def init_streamer(self, source=None, dtype=None):
    """
    Build a compiled closure over the data that computes one line at a time.

    Parameters
    ----------
    source : DataTree, optional
        The data this flow is evaluated against.  When omitted, the
        tree used to initialize this flow is used.
    dtype : str or dtype, default float32
        Override the default dtype for the result.  May trigger
        re-compilation of the underlying code.

    Returns
    -------
    callable
        A numba-compiled function ``streamer(c, out=None)``.  For a root
        dataset with one dimension it fills (or allocates) a vector with
        one entry per raw function; with two dimensions it fills (or
        allocates) an array with one row per item on the second
        dimension and one column per raw function.

    Raises
    ------
    NotImplementedError
        If the root dataset has neither one nor two usable dimensions.
    """
    source = self.tree if source is None else source
    dtype = np.float32 if dtype is None else dtype

    # These leading line maker arguments are passed by the streamer itself
    # (presumably the output buffer and position indexes -- TODO confirm),
    # so they are not looked up in the data.
    excluded = ("intermediate", "j0", "j1")
    named_args = tuple(
        name
        for name in inspect.getfullargspec(self._linemaker.py_func).args
        if name not in excluded
    )

    # Collect an array for every mangled ``__{subspace}__{name}`` argument,
    # first from each subspace's own members and then from its indexes.
    general_mapping = {}
    for space_name, dataset in source.subspaces.items():
        for members in (dataset, dataset.indexes):
            for member in members:
                mangled_key = f"__{space_name}__{member}"
                if mangled_key in named_args:
                    general_mapping[mangled_key] = dataset[member].to_numpy()
    selected_args = tuple(general_mapping[name] for name in named_args)

    n_funcs = len(self._raw_functions)
    root_sizes = source.root_dataset.sizes
    argshape = tuple(
        root_sizes[dim]
        for dim in presorted(root_sizes, self.dim_order, self.dim_exclude)
    )
    if len(argshape) not in (1, 2):
        raise NotImplementedError(
            f"root tree with {len(argshape)} dims {argshape=}"
        )

    linemaker = self._linemaker

    if len(argshape) == 1:

        @nb.njit
        def streamer(c, out=None):
            if out is None:
                result = np.zeros(n_funcs, dtype=dtype)
            else:
                result = out
                assert result.ndim == 1
                assert result.size == n_funcs
            linemaker(result, c, *selected_args)
            return result

    else:
        n_alts = argshape[1]

        @nb.njit
        def streamer(c, out=None):
            if out is None:
                result = np.zeros((n_alts, n_funcs), dtype=dtype)
            else:
                result = out
                assert result.shape == (n_alts, n_funcs)
            for i in range(n_alts):
                linemaker(result[i, :], c, i, *selected_args)
            return result

    return streamer