# Sharrow Basics

This notebook provides a short walkthrough of some of the basic features of the `sharrow` library.

In [None]:
from io import StringIO

import numba as nb
import numpy as np
import pandas as pd
import xarray as xr

import sharrow as sh

sh.__version__

In [None]:
# check versions
import packaging.version

assert packaging.version.parse(xr.__version__) >= packaging.version.parse("0.20.2")

## Example Data

We'll begin by importing some example data to work with.  We'll be using 
some test data taken from the MTC example in the ActivitySim project, including 
tables of data for households and persons, as well as a set of 
skims containing transportation level of service information for travel around
a tiny slice of San Francisco.

The households and persons are typical tabular data, and 
each can be read in and stored as a `pandas.DataFrame`.

In [None]:
households = sh.example_data.get_households()
households.head()

In [None]:
# TEST households content
assert len(households) == 5000
assert "income" in households
assert households.index.name == "HHID"

In [None]:
persons = sh.example_data.get_persons()
persons.head()

In [None]:
assert len(persons) == 8212
assert "household_id" in persons
assert persons.index.name == "PERID"

The skims, on the other hand, are not just simple tabular data, but rather a 
multi-dimensional representation of the transportation system, indexed by origin,
destination, and time of day. Rather than using a single DataFrame for this data,
we store it as a multi-dimensional `xarray.Dataset`.

In [None]:
skims = sh.example_data.get_skims()
skims

For tabular data, sharrow accepts either pandas DataFrames or xarray Datasets, 
but to ensure consistency the former are converted into the latter automatically when
they are used with sharrow.  You can also make the conversion manually:

In [None]:
xr.Dataset(persons)

Suppose we want to simulate a tour mode choice.  Normally we would first run
a series of other models to generate these tours and their destinations,
but let's skip that for now and make up some random data to work with.  We'll 
randomly choose (with replacement) 100,000 people, and send each of them to a randomly
chosen zone, with random outbound and inbound time periods.

In [None]:
def random_tours(n_tours=100_000, seed=42):
    rng = np.random.default_rng(seed)
    n_zones = skims.sizes["dtaz"]
    return pd.DataFrame(
        {
            "PERID": rng.choice(persons.index, size=n_tours),
            "dest_taz_idx": rng.choice(n_zones, size=n_tours),
            "out_time_period": rng.choice(skims.time_period, size=n_tours),
            "in_time_period": rng.choice(skims.time_period, size=n_tours),
        }
    ).rename_axis("TOURIDX")


tours = random_tours()
tours.head()

In [None]:
# TEST
assert tours.index.name == "TOURIDX"
assert 0 in tours.head().dest_taz_idx

Of note in this table, we include destination TAZs by index (position), not by
label, so we can observe a TAZ index of `0` even though the first TAZ ID is 1.
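
The difference between a position and a label can be sketched with plain pandas. The zone labels below are made up for illustration and are not taken from the example skims.

```python
import numpy as np
import pandas as pd

# Hypothetical zone labels: TAZ IDs start at 1, positions start at 0.
taz_ids = pd.Index([1, 2, 3, 4, 5], name="TAZ")

# Positions (what `dest_taz_idx` holds) index straight into an array.
dest_taz_idx = np.array([0, 4, 2])
dest_taz_labels = taz_ids[dest_taz_idx].to_numpy()

# Going the other way requires a lookup from label to position.
recovered_idx = taz_ids.get_indexer(dest_taz_labels)

print(dest_taz_labels)  # [1 5 3]
print(recovered_idx)    # [0 4 2]
```

Storing positions avoids the label lookup each time the data is used.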

## Spec Files

Now that we've got our tours to work with, we'll also need 
an expression "spec" file that defines the utility function
terms and coefficients.  Following the ActivitySim format, we
can write a mini-spec file as appears below.  Each line of this
CSV file has an expression that can be evaluated in the context
of the various tables and datasets shown above, plus a set of 
coefficients that apply for that expression across various modal 
alternatives (drive, walk, and transit in this example).

In [None]:
mini_spec = """
Label,Expression,DRIVE,WALK,TRANSIT
Drive Time,odt_skims['SOV_TIME'] + dot_skims['SOV_TIME'],-0.0134,,
Transit IVT,(odt_skims['WLK_LOC_WLK_TOTIVT']/100 + dot_skims['WLK_LOC_WLK_TOTIVT']/100),,,-0.0134
Transit Wait Time,short_i_wait_mult * ((odt_skims['WLK_LOC_WLK_IWAIT']/100).clip(upper=shortwait) + (dot_skims['WLK_LOC_WLK_IWAIT']/100).clip(upper=shortwait)),,,-0.0134
Income,hh.income > income_breakpoints[2],,-0.2,
Constant,one,,-0.4,-0.55
"""

We'll use pandas to load these values into a DataFrame.

In [None]:
spec = pd.read_csv(StringIO(mini_spec), index_col="Label")
spec

In [None]:
# TEST check spec
assert spec.index.name == "Label"
assert all(spec.columns == ["Expression", "DRIVE", "WALK", "TRANSIT"])

## Data Trees and Flows

Then, it's time to prepare our data.  We'll create a `DataTree`
that defines the relationships among all the datasets we're working
with.  This is a tree in the mathematical sense, with nodes referencing
the datasets and edges representing the relationships.

In [None]:
income_breakpoints = nb.typed.Dict.empty(nb.types.int32, nb.types.int32)
income_breakpoints[0] = 15000
income_breakpoints[1] = 30000
income_breakpoints[2] = 60000

tree = sh.DataTree(
    tour=tours,
    person=persons,
    hh=households,
    odt_skims=skims,
    dot_skims=skims,
    relationships=(
        "tour.PERID @ person.PERID",
        "person.household_id @ hh.HHID",
        "hh.TAZ @ odt_skims.otaz",
        "tour.dest_taz_idx -> odt_skims.dtaz",
        "tour.out_time_period @ odt_skims.time_period",
        "tour.dest_taz_idx -> dot_skims.otaz",
        "hh.TAZ @ dot_skims.dtaz",
        "tour.in_time_period @ dot_skims.time_period",
    ),
    extra_vars={
        "shortwait": 3,
        "one": 1,
    },
    aux_vars={
        "short_i_wait_mult": 0.75,
        "income_breakpoints": income_breakpoints,
    },
)

The first named dataset we include, `tour`, is by default the root node of this data tree.
We then can define an arbitrary number of other named data nodes.  Here, we add `person`, `hh`,
`odt_skims` and `dot_skims`.  Note that these last two are actually two different names for the
same underlying dataset, and for each name we will next define a unique set of relationships.

All data nodes in this tree are stored as `Dataset` objects. We can give a pandas DataFrame
in this constructor instead, but it will be automatically converted into a one-dimensional `Dataset`.
The conversion is no-copy if possible (and it is usually possible) so no additional memory is
consumed in the conversion.

The `relationships` argument defines the links of the data tree. Each relationship maps a particular variable
in a named upstream dataset to a particular dimension of a named downstream dataset.  For example,
`"person.household_id @ hh.HHID"` tells the tree that the `household_id` variable in the `person` 
dataset contains labels (`@`) that map to the `HHID` dimension of the `hh` dataset.

In addition to mapping by label, we can also map by position, by using the `->` operator in the
relationship string instead of `@`.  In the example above, we map the tour destination TAZs in
this manner, as the `dest_taz_idx` variable in the `tour` dataset contains positional references
instead of labels.
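
To make the two kinds of link concrete, here is a rough sketch in plain pandas and NumPy of the lookups that such relationships describe. The tiny tables and the one-dimensional `dist_by_dtaz` array are invented for illustration, and this is not how sharrow implements the tree internally.

```python
import numpy as np
import pandas as pd

# Tiny made-up tables mirroring the structure used above.
hh = pd.DataFrame({"income": [20000, 90000]}, index=pd.Index([10, 11], name="HHID"))
person = pd.DataFrame(
    {"household_id": [11, 10, 11]}, index=pd.Index([100, 101, 102], name="PERID")
)
tour = pd.DataFrame({"PERID": [102, 100, 101, 101], "dest_taz_idx": [2, 0, 1, 1]})
dist_by_dtaz = np.array([1.5, 2.5, 3.5])  # a one-dimensional stand-in for a skim

# "tour.PERID @ person.PERID": values are labels, so find their positions first.
person_pos = person.index.get_indexer(tour["PERID"])
# "person.household_id @ hh.HHID": another label lookup, one step further down.
hh_pos = hh.index.get_indexer(person["household_id"].to_numpy()[person_pos])
tour_income = hh["income"].to_numpy()[hh_pos]

# "tour.dest_taz_idx -> ...dtaz": values are already positions, so index directly.
tour_dist = dist_by_dtaz[tour["dest_taz_idx"].to_numpy()]

print(tour_income)  # [90000 90000 20000 20000]
print(tour_dist)    # [3.5 1.5 2.5 2.5]
```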

A special case for the relationship mapping is available when the source variable
in the upstream dataset is explicitly categorical.  In this case, sharrow checks that
the categories exactly match the labels in the referenced downstream dataset dimension,
and that there are no missing categorical values. If they do match and there are no
missing values, the code points of the categories are used as positional mapping
references, which is both memory and runtime efficient.  If they *don't* match, an
error is raised, as it is presumed that the user has made a mistake.  In theory 
sharrow could unravel the category values and do the mapping by label, but this would
be a cumbersome operation, contrary to the efficiency goals of the library.
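
The checks described above can be sketched with a pandas `Categorical`. The helper `positions_from_categorical` and the time period labels are hypothetical, written only to illustrate the idea; sharrow's own checks may differ in detail.

```python
import numpy as np
import pandas as pd

# Hypothetical labels for the downstream dimension.
time_period_labels = pd.Index(["EA", "AM", "MD", "PM", "EV"])


def positions_from_categorical(values: pd.Categorical, labels: pd.Index) -> np.ndarray:
    """Return category codes as positions, after the two checks described above."""
    if not values.categories.equals(labels):
        raise ValueError("categories do not match the dimension labels")
    if (values.codes < 0).any():  # pandas marks missing values with code -1
        raise ValueError("categorical contains missing values")
    return values.codes


out_period = pd.Categorical(["AM", "EV", "AM", "MD"], categories=time_period_labels)
print(positions_from_categorical(out_period, time_period_labels))  # [1 4 1 2]
```

Because the codes are small integers that already point at the right positions, no label lookup is needed at evaluation time.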

Lastly, our tree definition includes a few named constants, which are just fixed values defined
in a separate dictionary. These are shown in two groups, `extra_vars` and `aux_vars`. The values 
in `extra_vars` get hard-coded into the compiled results, effectively the 
same as if their values were expanded and written into expressions in the `spec` directly. This is
generally most efficient if the values will never change.  On the other hand, `aux_vars` will be 
passed by reference into the compiled results. These values need to be numba-safe objects, so
for instance a regular Python dictionary can't be used, but a numba typed Dict is acceptable.
So long as the data type and dimensionality of the values in `aux_vars` remain constant, the 
actual values can be changed later (i.e. after compilation).
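
As a loose analogy only, the distinction can be sketched with the standard library `ast` module: one name is replaced by a literal before compiling, while the other is looked up each time the code runs. The helper `inline_constants` and the simplified expression are hypothetical, and sharrow's numba-based code generation is not shown here.

```python
import ast


def inline_constants(expression: str, constants: dict) -> str:
    """Rewrite `expression` so that named constants become literal values."""

    class Inliner(ast.NodeTransformer):
        def visit_Name(self, node):
            if node.id in constants:
                return ast.copy_location(ast.Constant(constants[node.id]), node)
            return node

    tree = Inliner().visit(ast.parse(expression, mode="eval"))
    return ast.unparse(ast.fix_missing_locations(tree))


# `shortwait` plays the role of an extra_var: its value is written into the source.
source = inline_constants("short_i_wait_mult * min(wait, shortwait)", {"shortwait": 3})
print(source)  # short_i_wait_mult * min(wait, 3)

# `short_i_wait_mult` plays the role of an aux_var: it is supplied at call time,
# so its value can change without rebuilding the compiled code object.
code = compile(source, "<expr>", "eval")
print(eval(code, {"min": min}, {"short_i_wait_mult": 0.75, "wait": 5}))  # 2.25
print(eval(code, {"min": min}, {"short_i_wait_mult": 0.5, "wait": 5}))  # 1.5
```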

Once we have defined our data tree, we can use it along with the `spec`, to compute the utility
for various alternatives in the choice model.  Sharrow allows us to compile this utility function
into a `Flow`, which can be reused for massive speed gains on later utility evaluations.

In [None]:
flow = tree.setup_flow(spec.Expression)

To use a `Flow` for preparing the array of data that backs the utility
function, we can call the `load()` method. The first time we call `load()`,
it takes a (relatively) long time to evaluate, as the expressions are compiled
and that compiled code is cached to disk.

In [None]:
# TEST

assert flow.tree.aux_vars["short_i_wait_mult"] == 0.75
assert flow.tree.aux_vars["income_breakpoints"][2] == 60000

In [None]:
%time flow.load()

In [None]:
# TEST utility data
assert flow.check_cache_misses(fresh=False)
actual = flow.load()
expected = np.array(
    [
        [9.4, 16.9572, 4.5, 0.0, 1.0],
        [9.32, 14.3628, 4.5, 1.0, 1.0],
        [7.62, 11.0129, 4.5, 1.0, 1.0],
        [4.25, 7.6692, 2.50065, 0.0, 1.0],
        [6.16, 8.2186, 3.387825, 0.0, 1.0],
        [4.86, 4.9288, 4.5, 0.0, 1.0],
        [1.07, 0.0, 0.0, 0.0, 1.0],
        [8.52, 11.615499, 3.260325, 0.0, 1.0],
        [11.74, 16.2798, 3.440325, 0.0, 1.0],
        [10.48, 13.3974, 3.942825, 0.0, 1.0],
    ],
    dtype=np.float32,
)

np.testing.assert_array_almost_equal(actual[:5], expected[:5])
np.testing.assert_array_almost_equal(actual[-5:], expected[-5:])
assert actual.shape == (len(tours), len(spec))

Subsequent calls to `load()` are much faster.

In [None]:
%time flow.load()

In [None]:
# TEST compile flags
flow.load(compile_watch=False)
import pytest

with pytest.raises(AttributeError):
    compiled_recently = (
        flow.compiled_recently
    )  # attribute does not exist if compile_watch flag is off

It's not faster because it's cached the data, but because it's cached the compiled code.
(Setting the `compile_watch` argument to a truthy value will trigger a check of the 
cache files and emit a warning message if recompilation was triggered.)
We can swap out the `tour` node in the tree for a different set of (similarly formatted)
tours, and re-evaluate at that fast speed.

In [None]:
tours_2 = random_tours(seed=43)
tours_2.head()

Note that the flow requires not just a base dataset but a whole DataTree to operate,
so to re-evaluate with a new `tours` we need to make a DataTree with `replace_datasets`.
Fortunately, this operation is no-copy so it doesn't consume much memory.  If all the 
datasets in a tree are linked by position (instead of by label) this would be almost 
instantaneous, but since our example tree here has tours linked by label it takes just a
moment to rebuild the linkages.

In [None]:
tree_2 = tree.replace_datasets(tour=tours_2)

In [None]:
# TEST
from pytest import approx

assert tree_2.aux_vars["short_i_wait_mult"] == 0.75
assert tree_2.aux_vars["income_breakpoints"][2] == approx(60000)

In [None]:
%time flow.load(tree_2)

In [None]:
# TEST that aux_vars also work with arrays
tree_a = tree_2.replace_datasets(tour=tours)
tree_a.aux_vars["income_breakpoints"] = np.asarray([1, 2, 60000])
actual = flow.load(tree_a)
expected = np.array(
    [
        [9.4, 16.9572, 4.5, 0.0, 1.0],
        [9.32, 14.3628, 4.5, 1.0, 1.0],
        [7.62, 11.0129, 4.5, 1.0, 1.0],
        [4.25, 7.6692, 2.50065, 0.0, 1.0],
        [6.16, 8.2186, 3.387825, 0.0, 1.0],
        [4.86, 4.9288, 4.5, 0.0, 1.0],
        [1.07, 0.0, 0.0, 0.0, 1.0],
        [8.52, 11.615499, 3.260325, 0.0, 1.0],
        [11.74, 16.2798, 3.440325, 0.0, 1.0],
        [10.48, 13.3974, 3.942825, 0.0, 1.0],
    ],
    dtype=np.float32,
)

np.testing.assert_array_almost_equal(actual[:5], expected[:5])
np.testing.assert_array_almost_equal(actual[-5:], expected[-5:])
assert actual.shape == (len(tours), len(spec))

The load function also has some other features, like nicely formatting the output
into a DataFrame.

In [None]:
df = flow.load_dataframe()
df

In [None]:
# TEST df
assert len(df) == len(tours)
pd.testing.assert_index_equal(
    df.columns,
    pd.Index(["Drive Time", "Transit IVT", "Transit Wait Time", "Income", "Constant"]),
)
expected_df_head = pd.read_csv(
    StringIO(
        """,Drive Time,Transit IVT,Transit Wait Time,Income,Constant
0,9.4,16.9572,4.5,0.0,1.0
1,9.32,14.3628,4.5,1.0,1.0
2,7.62,11.0129,4.5,1.0,1.0
3,4.25,7.6692,2.50065,0.0,1.0
4,6.16,8.2186,3.387825,0.0,1.0"""
    ),
    index_col=0,
).astype(np.float32)
pd.testing.assert_frame_equal(df.head(), expected_df_head)

## Linear-in-Parameters Functions

When the `spec` represents a linear-in-parameters utility function, the data 
we get out of the `load()` function represents one matrix in a dot-product, and
the coefficients in the `spec` provide the other matrix.  We might look to 
use the efficient linear algebra algorithms embedded in `np.dot` to compute the
utility, like this:

In [None]:
x = flow.load()
b = spec.iloc[:, 1:].fillna(0).astype(np.float32).values
np.dot(x, b)

But `sharrow` provides a substantially faster option, by embedding
the dot product directly into the compiled code and never instantiating the
full `x` array in memory at all.
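
The idea of fusing the dot product into the evaluation loop can be sketched in plain Python. The data, the per-row `terms` callables, and `fused_dot` are all invented for illustration; sharrow generates and compiles its own code, and a pure Python loop like this one is slow and only shows the structure.

```python
import numpy as np

rng = np.random.default_rng(0)
n_rows = 1_000
drive_time = rng.random(n_rows) * 30
transit_ivt = rng.random(n_rows) * 40

# One callable per expression (one per row of the spec), evaluated for a single row.
terms = (
    lambda i: drive_time[i],
    lambda i: transit_ivt[i],
    lambda i: 1.0,
)
# Coefficients: one row per expression, one column per alternative.
b = np.array([[-0.0134, 0.0, 0.0], [0.0, 0.0, -0.0134], [0.0, -0.4, -0.55]])


def fused_dot(terms, b, n_rows):
    """Accumulate utilities row by row; only one scalar of `x` exists at a time."""
    u = np.zeros((n_rows, b.shape[1]))
    for i in range(n_rows):
        for k, term in enumerate(terms):
            x_ik = term(i)
            for j in range(b.shape[1]):
                u[i, j] += x_ik * b[k, j]
    return u


# Reference: materialize the full x array, then multiply.
x = np.column_stack([drive_time, transit_ivt, np.ones(n_rows)])
assert np.allclose(fused_dot(terms, b, n_rows), x @ b)
```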

In [None]:
%time flow.dot(b)

In [None]:
u = flow.dot(b)
u

In [None]:
# TEST utility
np.testing.assert_array_almost_equal(u, np.dot(x, b))

As before, the compiler runs only the first time we apply this 
function with this structure, and subsequent runs are faster, even with
different source data.

In [None]:
%time flow.dot(b, source=tree_2)

As with the plain `load` method, the `dot` method also has some formatted output versions.
For example, the `dot_dataarray` method returns a `DataArray`.

In [None]:
flow.dot_dataarray(b, source=tree_2)

This works even better if the coefficients are given as a DataArray too, so it 
can harvest dimension names and coordinates as appropriate.

In [None]:
B = xr.DataArray(
    spec.iloc[:, 1:].fillna(0).astype(np.float32), dims=("expressions", "modes")
)
flow.dot_dataarray(B, source=tree_2)

## Multinomial Logit Simulation

The next level of flow evaluation is made by treating the dot-product as a
linear-in-parameters multinomial logit (MNL) utility function, and making simulated
choices based on that model.  To do this, we'll need to provide the random
draws as a function input (which also lets us attach any randomization engine
we prefer, e.g. a reproducible random generator).  For this example, we'll 
create one random (uniform) draw for each tour.

In [None]:
draws = np.random.default_rng(321).random(size=tree.shape[0])

Given those draws, we use the `logit_draws` method to build and apply an 
MNL simulator, which returns both the choices and the probability that
was computed for each chosen alternative. 

In [None]:
choices, choice_probs = flow.logit_draws(b, draws)

In [None]:
%time choices, choice_probs = flow.logit_draws(b, draws)

As this is the most complex flow processor,
it takes the longest to compile, but after compilation it runs quite efficiently.
We can see here the whole MNL simulation process for this data requires only a few 
milliseconds more time than just computing the utilities.
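
For reference, an MNL simulation of this kind can be sketched in NumPy as below: compute probabilities from the utilities, then pick the alternative whose cumulative probability interval contains the draw. The function `mnl_draws` is a hypothetical stand-in, and the assumption that the compiled simulator follows the same inverse-CDF logic is not confirmed here.

```python
import numpy as np


def mnl_draws(utility: np.ndarray, draws: np.ndarray):
    """Pick one alternative per row by inverting the cumulative probabilities."""
    # Subtract the row maximum before exponentiating, for numerical stability.
    expu = np.exp(utility - utility.max(axis=1, keepdims=True))
    probs = expu / expu.sum(axis=1, keepdims=True)
    cum = probs.cumsum(axis=1)
    # First alternative whose cumulative probability exceeds the draw;
    # the clip guards against round-off in the final cumulative value.
    picks = (draws[:, None] >= cum).sum(axis=1).clip(max=utility.shape[1] - 1)
    pick_probs = probs[np.arange(len(picks)), picks]
    return picks, pick_probs


demo_u = np.array([[0.0, 0.0, 0.0], [0.0, -1.0, -2.0]])
demo_choices, demo_probs = mnl_draws(demo_u, np.array([0.5, 0.95]))
print(demo_choices)  # [1 2]
```

Supplying the draws as an input, rather than generating them inside the function, is what keeps the result reproducible.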

In [None]:
choices2, choice_probs2 = flow.logit_draws(b, draws, source=tree_2)

In [None]:
%time choices2, choice_probs2 = flow.logit_draws(b, draws, source=tree_2)

The resulting choices are the index position of the choices, not the labels.

In [None]:
choices

But if we want the labels, it's easy enough to convert these indexes into labels.

In [None]:
B.modes[choices]

In [None]:
# TEST mnl choices
uz = np.exp(flow.dot(b))
uz = uz / uz.sum(1)[:, None]
np.testing.assert_array_almost_equal(
    uz[range(uz.shape[0]), choices.ravel()],
    choice_probs.ravel(),
)

In [None]:
# TEST
choices_darr, choice_probs_darr = flow.logit_draws(b, draws, as_dataarray=True)
assert choices_darr.dims == ("TOURIDX",)
assert choices_darr.shape == (100000,)
assert choice_probs_darr.dims == ("TOURIDX",)
assert choice_probs_darr.shape == (100000,)

## Nested Logit Simulation

Sharrow can also apply nested logit models.  To do so, you'll also need
to install a recent version of *larch* (e.g. `conda install "larch>=5.7.1" -c conda-forge`).

The nesting tree can be defined as usual in Larch, or you can use the
`construct_nesting_tree` convenience function to read in a nesting tree
definition according to the usual ActivitySim yaml notation, like this:

In [None]:
nesting_settings = """
NESTS:
  name: root
  coefficient: coef_nest_root
  alternatives:
      - name: MOTORIZED
        coefficient: coef_nest_motor
        alternatives:
            - DRIVE
            - TRANSIT
      - WALK
"""

import yaml

from sharrow.nested_logit import construct_nesting_tree

nesting_settings = yaml.safe_load(nesting_settings)["NESTS"]
nest_tree = construct_nesting_tree(
    alternatives=spec.columns[1:], nesting_settings=nesting_settings
)

In [None]:
nest_tree

Once the nesting tree is defined, it needs to be converted to operating arrays, using the `as_arrays` method (available in larch 5.7.1 and later).  Since we are not estimating a nested logit model and are just applying one,
we can give the parameter values as a dictionary instead of a `larch.Model` to link against.

In [None]:
nesting = nest_tree.as_arrays(
    trim=True, parameter_dict={"coef_nest_motor": 0.5, "coef_nest_root": 1.0}
)

This dictionary of arrays can be passed in to the `logit_draws` function to compile a nested logit model
instead of a simple MNL.
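
For intuition, the probabilities of this particular two-level nesting structure can be written out directly in NumPy. This is a sketch of the standard textbook formulation, with the root coefficient taken as 1.0 and a hypothetical helper `nested_probs`; it is not the array-based code that sharrow compiles.

```python
import numpy as np


def nested_probs(u: np.ndarray, mu: float) -> np.ndarray:
    """Probabilities for columns (DRIVE, WALK, TRANSIT), with DRIVE and TRANSIT nested."""
    drive, walk, transit = u[:, 0], u[:, 1], u[:, 2]
    # Logsum of the MOTORIZED nest, scaled by its nesting coefficient `mu`.
    nest_sum = np.exp(drive / mu) + np.exp(transit / mu)
    nest_logsum = mu * np.log(nest_sum)
    # Upper level: MOTORIZED nest versus WALK (root coefficient taken as 1.0).
    denom = np.exp(nest_logsum) + np.exp(walk)
    p_nest = np.exp(nest_logsum) / denom
    # Lower level: conditional choice within the nest.
    p_drive = p_nest * np.exp(drive / mu) / nest_sum
    p_transit = p_nest * np.exp(transit / mu) / nest_sum
    return np.column_stack([p_drive, np.exp(walk) / denom, p_transit])


demo_u = np.array([[-0.2, -0.6, -0.9], [-0.1, -0.4, -1.5]])
mnl = np.exp(demo_u) / np.exp(demo_u).sum(axis=1, keepdims=True)

assert np.allclose(nested_probs(demo_u, mu=0.5).sum(axis=1), 1.0)
assert np.allclose(nested_probs(demo_u, mu=1.0), mnl)  # collapses to MNL when mu == 1
```

Setting the nest coefficient to 1.0 removes the nesting effect, so the result matches the plain MNL probabilities.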

In [None]:
%time choices_nl, choice_probs_nl = flow.logit_draws(b, draws, nesting=nesting)

In [None]:
%time choices2_nl, choice2_probs_nl = flow.logit_draws(b, draws, source=tree_2, nesting=nesting)

In [None]:
# TEST
choices2_nl_darr, choice2_probs_nl_darr = flow.logit_draws(
    b, draws, source=tree_2, nesting=nesting, as_dataarray=True
)
assert choices2_nl_darr.dims == ("TOURIDX",)
assert choices2_nl_darr.shape == (100000,)
assert choice2_probs_nl_darr.dims == ("TOURIDX",)
assert choice2_probs_nl_darr.shape == (100000,)

In [None]:
# TEST devolve NL to MNL
choices_nl_1, choice_probs_nl_1 = flow.logit_draws(
    b,
    draws,
    nesting=nest_tree.as_arrays(
        trim=True, parameter_dict={"coef_nest_motor": 1.0, "coef_nest_root": 1.0}
    ),
)
assert (choices_nl_1 == choices).all()
assert choice_probs == approx(choice_probs_nl_1)

For nested logit models, computing just the logsums is faster than generating probabilities (and making choices) so the `logsums=1` argument allows you to short-circuit the computations if you only want the logsums.
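
A sketch of why the logsum alone is cheaper: for the nesting structure above it needs only the upward pass through the tree, with no probabilities or draws. The helper `nested_logsum` is hypothetical and assumes a root coefficient of 1.0; the optional mask writes zeros for rows that are masked out.

```python
import numpy as np


def nested_logsum(u, mu, mask=None):
    """Root logsum for columns (DRIVE, WALK, TRANSIT); DRIVE and TRANSIT share a nest."""
    drive, walk, transit = u[:, 0], u[:, 1], u[:, 2]
    # np.logaddexp(a, b) computes log(exp(a) + exp(b)) in a numerically stable way.
    nest_logsum = mu * np.logaddexp(drive / mu, transit / mu)
    root_logsum = np.logaddexp(nest_logsum, walk)
    if mask is not None:
        # Rows with a zero mask are skipped and reported as 0.
        root_logsum = np.where(mask, root_logsum, 0.0)
    return root_logsum


demo_u = np.array([[0.0, 0.0, 0.0], [-0.2, -0.6, -0.9]])
# With mu == 1 the nested logsum equals the plain MNL logsum, log(sum(exp(u))).
assert np.allclose(nested_logsum(demo_u, mu=1.0), np.log(np.exp(demo_u).sum(axis=1)))
print(nested_logsum(demo_u, mu=0.5, mask=np.array([1, 0])))
```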

In [None]:
flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=1)

In [None]:
# TEST
_ch, _pr, _pc, _ls = flow.logit_draws(
    b, draws, source=tree_2, nesting=nesting, logsums=1
)
assert _ch is None
assert _pr is None
assert _pc is None
assert _ls.size == 100000
np.testing.assert_array_almost_equal(
    _ls[:5], [0.532791, 0.490935, 0.557529, 0.556371, 0.54812]
)
np.testing.assert_array_almost_equal(
    _ls[-5:], [0.452682, 0.465422, 0.554312, 0.525064, 0.515226]
)

_ch, _pr, _pc, _ls = flow.logit_draws(
    b,
    draws,
    source=tree_2,
    nesting=nesting,
    logsums=1,
    as_dataarray=True,
)
assert _ch is None
assert _pr is None
assert _pc is None
assert _ls.size == 100000
assert _ls.dims == ("TOURIDX",)
assert _ls.shape == (100000,)

In [None]:
# TEST masking
masker = np.zeros(draws.shape, dtype=np.int8)
masker[::2] = 1
_ch_m, _pr_m, _pc_m, _ls_m = flow.logit_draws(
    b, draws, source=tree_2, nesting=nesting, logsums=1, mask=masker
)

assert _ls_m == approx(np.where(masker, _ls, 0))
assert (_ch_m, _pr_m, _pc_m) == (None, None, None)

Note that for consistency, the first three elements of the returned tuple
still correspond to the choices, probabilities of choices, and pick counts,
but with `logsums=1` each of them is `None`.

To get *both* the logsums and the choices, set `logsums=2`.

In [None]:
flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=2)

In [None]:
# TEST
_ch, _pr, _pc, _ls = flow.logit_draws(
    b, draws, source=tree_2, nesting=nesting, logsums=2
)
assert _ch.size == 100000
assert _pr.size == 100000
assert _pc is None
assert _ls.size == 100000
np.testing.assert_array_almost_equal(_ch[:5], [1, 2, 1, 1, 1])
np.testing.assert_array_almost_equal(_ch[-5:], [0, 1, 0, 1, 0])
np.testing.assert_array_almost_equal(
    _pr[:5], [0.393454, 0.16956, 0.38384, 0.384285, 0.387469]
)
np.testing.assert_array_almost_equal(
    _pr[-5:], [0.503606, 0.420874, 0.478898, 0.396506, 0.468742]
)
np.testing.assert_array_almost_equal(
    _ls[:5], [0.532791, 0.490935, 0.557529, 0.556371, 0.54812]
)
np.testing.assert_array_almost_equal(
    _ls[-5:], [0.452682, 0.465422, 0.554312, 0.525064, 0.515226]
)
_ch, _pr, _pc, _ls = flow.logit_draws(
    b, draws, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True
)
assert _ch.size == 100000
assert _ch.dims == ("TOURIDX",)
assert _ch.shape == (100000,)
assert _pr.size == 100000
assert _pr.dims == ("TOURIDX",)
assert _pr.shape == (100000,)
assert _ls.size == 100000
assert _ls.dims == ("TOURIDX",)
assert _ls.shape == (100000,)

In [None]:
# TEST
draws_many = np.random.default_rng(42).random(size=(tree.shape[0], 5))
_ch, _pr, _pc, _ls = flow.logit_draws(
    b, draws_many, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True
)
assert _ch.dims == ("TOURIDX", "DRAW")
assert _ch.shape == (100000, 5)
assert _pr.dims == ("TOURIDX", "DRAW")
assert _pr.shape == (100000, 5)
assert _ls.dims == ("TOURIDX",)
assert _ls.shape == (100000,)
assert _pc is None

_ch, _pr, _pc, _ls = flow.logit_draws(
    b,
    draws_many,
    source=tree_2,
    nesting=nesting,
    logsums=2,
    as_dataarray=True,
    pick_counted=True,
)
assert _ch.dims == ("TOURIDX", "DRAW")
assert _ch.shape == (100000, 5)
assert _pr.dims == ("TOURIDX", "DRAW")
assert _pr.shape == (100000, 5)
assert _ls.dims == ("TOURIDX",)
assert _ls.shape == (100000,)
assert _pc.dims == ("TOURIDX", "DRAW")
assert _pc.shape == (100000, 5)

In [None]:
# TEST masking
masker = np.zeros(tree.shape[0], dtype=np.int8)
masker[::3] = 1

_ch_m, _pr_m, _pc_m, _ls_m = flow.logit_draws(
    b,
    draws_many,
    source=tree_2,
    nesting=nesting,
    logsums=2,
    as_dataarray=True,
    mask=masker,
    pick_counted=True,
)

assert (_ch_m.values == (np.where(np.expand_dims(masker, -1), _ch, -1))).all()
assert (_pr_m.values == (np.where(np.expand_dims(masker, -1), _pr, 0))).all()
assert (_pc_m.values == (np.where(np.expand_dims(masker, -1), _pc, 0))).all()
assert (_ls_m.values == (np.where(masker, _ls, 0))).all()

## Batch Simulation

Suppose we want to compute logsums not just for one destination, but for many destinations.  We can construct a `Dataset` with two dimensions to use at the top of our `DataTree`, one for the tours and one for the candidate destinations.

In [None]:
tour_by_dest = tree.subspaces["tour"]
tour_by_dest = tour_by_dest.assign_coords(
    {"CAND_DEST": xr.DataArray(np.arange(25), dims="CAND_DEST")}
)
tour_by_dest

Then we can create a DataTree very similar to the one above, using this two-dimensional root Dataset, but now we point to our destination zones from the new `CAND_DEST` dimension of the tour node, and then create a flow from that.
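
What the extra dimension buys us can be sketched with NumPy broadcasting: a column of origins indexed against a row of candidate destinations yields one value per (tour, destination) pair. The arrays below are random stand-ins, not the example skims.

```python
import numpy as np

rng = np.random.default_rng(0)
n_zones, n_tours, n_cand = 25, 4, 25
skim = rng.random((n_zones, n_zones))  # stand-in for one (otaz, dtaz) slice
home_idx = rng.integers(n_zones, size=n_tours)  # origin position for each tour
cand_dest = np.arange(n_cand)  # candidate destination positions

# Broadcasting a column of origins against a row of destinations yields
# one value per (tour, candidate destination) pair.
wide = skim[home_idx[:, None], cand_dest[None, :]]
print(wide.shape)  # (4, 25)

# Selecting each tour's actual destination from the wide result recovers the
# one-dimensional lookup.
dest_idx = rng.integers(n_cand, size=n_tours)
assert np.array_equal(wide[np.arange(n_tours), dest_idx], skim[home_idx, dest_idx])
```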

In [None]:
wide_tree = sh.DataTree(
    tour=tour_by_dest,
    person=persons,
    hh=households,
    odt_skims=skims,
    dot_skims=skims,
    relationships=(
        "tour.PERID @ person.PERID",
        "person.household_id @ hh.HHID",
        "hh.TAZ @ odt_skims.otaz",
        "tour.CAND_DEST -> odt_skims.dtaz",
        "tour.out_time_period @ odt_skims.time_period",
        "tour.CAND_DEST -> dot_skims.otaz",
        "hh.TAZ @ dot_skims.dtaz",
        "tour.in_time_period @ dot_skims.time_period",
    ),
    extra_vars={
        "shortwait": 3,
        "one": 1,
    },
    aux_vars={
        "short_i_wait_mult": 0.75,
        "income_breakpoints": income_breakpoints,
    },
    dim_order=("TOURIDX", "CAND_DEST"),
)
wide_flow = wide_tree.setup_flow(spec.Expression)

In [None]:
wide_logsums = wide_flow.logit_draws(b, logsums=1, compile_watch="simple")[-1]

In [None]:
%time wide_logsums = wide_flow.logit_draws(b, logsums=1, compile_watch="simple")[-1]
wide_logsums

In [None]:
# TEST
np.testing.assert_array_almost_equal(
    wide_logsums[:5, :5],
    np.array(
        [
            [0.759222, 0.75862, 0.744936, 0.758251, 0.737007],
            [0.671698, 0.671504, 0.663015, 0.661482, 0.667133],
            [0.670188, 0.678498, 0.687647, 0.691152, 0.715783],
            [0.760743, 0.769123, 0.763733, 0.784487, 0.802356],
            [0.73474, 0.743051, 0.751439, 0.754731, 0.778121],
        ],
        dtype=np.float32,
    ),
)
np.testing.assert_array_almost_equal(
    wide_logsums[-5:, -5:],
    np.array(
        [
            [0.719523, 0.755152, 0.739368, 0.762664, 0.764388],
            [0.740303, 0.678783, 0.649964, 0.694407, 0.681555],
            [0.758865, 0.663663, 0.637266, 0.673351, 0.65875],
            [0.765125, 0.706478, 0.676878, 0.717814, 0.713912],
            [0.73348, 0.683626, 0.647698, 0.69146, 0.673006],
        ],
        dtype=np.float32,
    ),
)

In [None]:
# TEST
np.testing.assert_array_almost_equal(
    wide_logsums[np.arange(len(tours)), tours["dest_taz_idx"].to_numpy()],
    flow.logit_draws(b, logsums=1)[-1],
)

In [None]:
# TEST
wide_logsums_ = wide_flow.logit_draws(
    b, logsums=1, compile_watch=True, as_dataarray=True
)[-1]
assert wide_logsums_.dims == ("TOURIDX", "CAND_DEST")
assert wide_logsums_.shape == (100000, 25)

In [None]:
# TEST
wide_draws = np.random.default_rng(42).random(size=wide_tree.shape)
with pytest.warns(sh.CacheMissWarning):
    wide_logsums_plus = wide_flow.logit_draws(
        b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws
    )
assert wide_logsums_plus[0].dims == ("TOURIDX", "CAND_DEST")
assert wide_logsums_plus[0].shape == (100000, 25)
assert wide_logsums_plus[3].dims == ("TOURIDX", "CAND_DEST")
assert wide_logsums_plus[3].shape == (100000, 25)

In [None]:
# TEST
wide_draws = np.random.default_rng(42).random(size=wide_tree.shape + (2,))
wide_logsums_plus = wide_flow.logit_draws(
    b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws
)
assert wide_logsums_plus[0].dims == ("TOURIDX", "CAND_DEST", "DRAW")
assert wide_logsums_plus[0].shape == (100000, 25, 2)
assert wide_logsums_plus[3].dims == ("TOURIDX", "CAND_DEST")
assert wide_logsums_plus[3].shape == (100000, 25)

In [None]:
# TEST masking
mask = np.zeros(wide_tree.shape, dtype=np.int8)
mask[::7] = 1
with pytest.warns(sh.CacheMissWarning):
    wide_logsums_mask = wide_flow.logit_draws(
        b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws, mask=mask
    )
assert wide_logsums_mask[0].dims == ("TOURIDX", "CAND_DEST", "DRAW")
assert wide_logsums_mask[0].shape == (100000, 25, 2)
assert wide_logsums_mask[3].dims == ("TOURIDX", "CAND_DEST")
assert wide_logsums_mask[3].shape == (100000, 25)

assert (
    wide_logsums_plus[0].where(np.expand_dims(mask, -1), -1) == wide_logsums_mask[0]
).all()
assert (
    wide_logsums_plus[1].where(np.expand_dims(mask, -1), 0) == wide_logsums_mask[1]
).all()
assert (wide_logsums_plus[3].where(mask, 0) == wide_logsums_mask[3]).all()

In [None]:
# TEST masking performance
import timeit
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("error")
    masked_time = timeit.timeit(
        lambda: wide_flow.logit_draws(
            b,
            logsums=2,
            compile_watch=True,
            as_dataarray=True,
            draws=wide_draws,
            mask=mask,
        ),
        number=1,
    )
    raw_time = timeit.timeit(
        lambda: wide_flow.logit_draws(
            b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws
        ),
        number=1,
    )
assert masked_time < raw_time  # generous, should be nearly 7 times faster
assert len(wide_flow.cache_misses["_imnl_plus1d"]) == 3