Sharrow Basics#

This notebook provides a short walkthrough of some of the basic features of the sharrow library.

from io import StringIO

import numba as nb
import numpy as np
import pandas as pd
import xarray as xr

import sharrow as sh

sh.__version__
'2.8.2'

Example Data#

We’ll begin by importing some example data to work with. We’ll be using some test data taken from the MTC example in the ActivitySim project, including tables of data for households and persons, as well as a set of skims containing transportation level of service information for travel around a tiny slice of San Francisco.

The households and persons are typical tabular data, and each can be read in and stored as a pandas.DataFrame.

households = sh.example_data.get_households()
households.head()
TAZ SERIALNO PUMA5 income PERSONS HHT UNITTYPE NOC BLDGSZ TENURE ... hschpred hschdriv htypdwel hownrent hadnwst hadwpst hadkids bucketBin originalPUMA hmultiunit
HHID
2717868 25 2715386 2202 361000 2 1 0 0 9 1 ... 0 0 2 1 0 0 0 3 2202 1
763899 6 5360279 2203 59220 1 4 0 0 9 3 ... 0 0 2 2 0 0 0 4 2203 1
2222791 9 77132 2203 197000 2 2 0 0 9 1 ... 0 0 2 1 0 0 1 5 2203 1
112477 17 3286812 2203 2200 1 6 0 0 8 3 ... 0 0 2 2 0 0 0 7 2203 1
370491 21 6887183 2203 16500 3 1 0 1 8 3 ... 1 0 2 2 0 0 0 7 2203 1

5 rows × 46 columns

persons = sh.example_data.get_persons()
persons.head()
household_id age RELATE ESR GRADE PNUM PAUG DDP sex WEEKS HOURS MSP POVERTY EARNS pagecat pemploy pstudent ptype padkid
PERID
25671 25671 47 1 6 0 1 0 0 1 0 0 6 39 0 6 3 3 4 2
25675 25675 27 1 6 7 1 0 0 2 52 40 2 84 7200 5 3 2 3 2
25678 25678 30 1 6 0 1 0 0 2 0 0 6 84 0 5 3 3 4 2
25683 25683 23 1 6 0 1 0 0 1 0 0 6 1 0 4 3 3 4 2
25684 25684 52 1 6 0 1 0 0 1 0 0 6 94 0 7 3 3 4 2

The skims, on the other hand, are not just simple tabular data, but rather a multi-dimensional representation of the transportation system, indexed by origin, destination, and time of day. Rather than using a single DataFrame for this data, we store it as a multi-dimensional xarray.Dataset.

skims = sh.example_data.get_skims()
skims
<xarray.Dataset> Size: 2MB
Dimensions:               (otaz: 25, dtaz: 25, time_period: 5)
Coordinates:
  * dtaz                  (dtaz) int64 200B 1 2 3 4 5 6 7 ... 20 21 22 23 24 25
  * otaz                  (otaz) int64 200B 1 2 3 4 5 6 7 ... 20 21 22 23 24 25
  * time_period           (time_period) <U2 40B 'EA' 'AM' 'MD' 'PM' 'EV'
Data variables: (12/170)
    DIST                  (otaz, dtaz) float32 2kB dask.array<chunksize=(25, 25), meta=np.ndarray>
    DISTBIKE              (otaz, dtaz) float32 2kB dask.array<chunksize=(25, 25), meta=np.ndarray>
    DISTWALK              (otaz, dtaz) float32 2kB dask.array<chunksize=(25, 25), meta=np.ndarray>
    DRV_COM_WLK_BOARDS    (otaz, dtaz, time_period) float32 12kB dask.array<chunksize=(25, 25, 5), meta=np.ndarray>
    DRV_COM_WLK_DDIST     (otaz, dtaz, time_period) float32 12kB dask.array<chunksize=(25, 25, 5), meta=np.ndarray>
    DRV_COM_WLK_DTIM      (otaz, dtaz, time_period) float32 12kB dask.array<chunksize=(25, 25, 5), meta=np.ndarray>
    ...                    ...
    WLK_TRN_WLK_IVT       (otaz, dtaz, time_period) float32 12kB dask.array<chunksize=(25, 25, 5), meta=np.ndarray>
    WLK_TRN_WLK_IWAIT     (otaz, dtaz, time_period) float32 12kB dask.array<chunksize=(25, 25, 5), meta=np.ndarray>
    WLK_TRN_WLK_WACC      (otaz, dtaz, time_period) float32 12kB dask.array<chunksize=(25, 25, 5), meta=np.ndarray>
    WLK_TRN_WLK_WAUX      (otaz, dtaz, time_period) float32 12kB dask.array<chunksize=(25, 25, 5), meta=np.ndarray>
    WLK_TRN_WLK_WEGR      (otaz, dtaz, time_period) float32 12kB dask.array<chunksize=(25, 25, 5), meta=np.ndarray>
    WLK_TRN_WLK_XWAIT     (otaz, dtaz, time_period) float32 12kB dask.array<chunksize=(25, 25, 5), meta=np.ndarray>

For tabular data, sharrow accepts either pandas DataFrames or xarray Datasets; to ensure consistency, DataFrames are converted into Datasets automatically when they are used with sharrow. You can also make the conversion manually:

xr.Dataset(persons)
<xarray.Dataset> Size: 1MB
Dimensions:       (PERID: 8212)
Coordinates:
  * PERID         (PERID) int64 66kB 25671 25675 25678 ... 7554887 7554903
Data variables: (12/19)
    household_id  (PERID) int64 66kB 25671 25675 25678 ... 2863552 2863568
    age           (PERID) int64 66kB 47 27 30 23 52 19 54 ... 82 68 68 93 76 82
    RELATE        (PERID) int64 66kB 1 1 1 1 1 1 1 1 ... 22 22 22 22 22 22 22 22
    ESR           (PERID) int64 66kB 6 6 6 6 6 6 6 6 6 6 ... 6 6 6 6 6 6 6 6 6 6
    GRADE         (PERID) int64 66kB 0 7 0 0 0 6 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
    PNUM          (PERID) int64 66kB 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1
    ...            ...
    EARNS         (PERID) int64 66kB 0 7200 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0
    pagecat       (PERID) int64 66kB 6 5 5 4 7 4 7 6 7 7 ... 8 8 9 9 9 8 8 9 8 9
    pemploy       (PERID) int64 66kB 3 3 3 3 3 3 3 3 3 3 ... 3 3 3 3 3 3 3 3 3 3
    pstudent      (PERID) int64 66kB 3 2 3 3 3 2 3 3 3 3 ... 3 3 3 3 3 3 3 3 3 3
    ptype         (PERID) int64 66kB 4 3 4 4 4 3 4 4 4 4 ... 5 5 5 5 5 5 5 5 5 5
    padkid        (PERID) int64 66kB 2 2 2 2 2 2 2 2 2 2 ... 2 2 2 2 2 2 2 2 2 2

Suppose we want to simulate a tour mode choice. Normally we’d have run through a series of upstream models to generate these tours and their destinations first, but let’s skip that for now and make up some random data to work with. We’ll randomly choose (with replacement) 100,000 persons, send each one to a randomly chosen destination zone, and assign random outbound and inbound time periods.

def random_tours(n_tours=100_000, seed=42):
    rng = np.random.default_rng(seed)
    n_zones = skims.sizes["dtaz"]
    return pd.DataFrame(
        {
            "PERID": rng.choice(persons.index, size=n_tours),
            "dest_taz_idx": rng.choice(n_zones, size=n_tours),
            "out_time_period": rng.choice(skims.time_period, size=n_tours),
            "in_time_period": rng.choice(skims.time_period, size=n_tours),
        }
    ).rename_axis("TOURIDX")


tours = random_tours()
tours.head()
PERID dest_taz_idx out_time_period in_time_period
TOURIDX
0 111378 18 EV AM
1 5058053 22 EV EV
2 3608229 14 EV EV
3 1874724 10 EV PM
4 1774303 0 EV PM

Of note in this table, we include destination TAZs by index (position), not label, so we can observe a TAZ index of 0 even though the first TAZ ID is 1.
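
Since dest_taz_idx holds positions rather than labels, we can look up (with plain numpy indexing, just for illustration) which TAZ labels those positions refer to; position 0 corresponds to the first label, TAZ 1.

# Illustrative only: the TAZ labels referenced by the first few positional indexes
skims["dtaz"].values[tours["dest_taz_idx"].to_numpy()[:5]]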

Spec Files#

Now that we’ve got our tours to work with, we’ll also need an expression “spec” file that defines the utility function terms and coefficients. Following the ActivitySim format, we can write a mini-spec file as appears below. Each line of this CSV file has an expression that can be evaluated in the context of the various tables and datasets shown above, plus a set of coefficients that apply for that expression across various modal alternatives (drive, walk, and transit in this example).

mini_spec = """
Label,Expression,DRIVE,WALK,TRANSIT
Drive Time,odt_skims['SOV_TIME'] + dot_skims['SOV_TIME'],-0.0134,,
Transit IVT,(odt_skims['WLK_LOC_WLK_TOTIVT']/100 + dot_skims['WLK_LOC_WLK_TOTIVT']/100),,,-0.0134
Transit Wait Time,short_i_wait_mult * ((odt_skims['WLK_LOC_WLK_IWAIT']/100).clip(upper=shortwait) + (dot_skims['WLK_LOC_WLK_IWAIT']/100).clip(upper=shortwait)),,,-0.0134
Income,hh.income > income_breakpoints[2],,-0.2,
Constant,one,,-0.4,-0.55
"""

We’ll use pandas to load these values into a DataFrame.

spec = pd.read_csv(StringIO(mini_spec), index_col="Label")
spec
Expression DRIVE WALK TRANSIT
Label
Drive Time odt_skims['SOV_TIME'] + dot_skims['SOV_TIME'] -0.0134 NaN NaN
Transit IVT (odt_skims['WLK_LOC_WLK_TOTIVT']/100 + dot_ski... NaN NaN -0.0134
Transit Wait Time short_i_wait_mult * ((odt_skims['WLK_LOC_WLK_I... NaN NaN -0.0134
Income hh.income > income_breakpoints[2] NaN -0.2 NaN
Constant one NaN -0.4 -0.5500

Data Trees and Flows#

Then, it’s time to prepare our data. We’ll create a DataTree that defines the relationships among all the datasets we’re working with. This is a tree in the mathematical sense, with nodes referencing the datasets and edges representing the relationships.

income_breakpoints = nb.typed.Dict.empty(nb.types.int64, nb.types.int64)
income_breakpoints[0] = 15000
income_breakpoints[1] = 30000
income_breakpoints[2] = 60000

tree = sh.DataTree(
    tour=tours,
    person=persons,
    hh=households,
    odt_skims=skims,
    dot_skims=skims,
    relationships=(
        "tour.PERID @ person.PERID",
        "person.household_id @ hh.HHID",
        "hh.TAZ @ odt_skims.otaz",
        "tour.dest_taz_idx -> odt_skims.dtaz",
        "tour.out_time_period @ odt_skims.time_period",
        "tour.dest_taz_idx -> dot_skims.otaz",
        "hh.TAZ @ dot_skims.dtaz",
        "tour.in_time_period @ dot_skims.time_period",
    ),
    extra_vars={
        "shortwait": 3,
        "one": 1,
    },
    aux_vars={
        "short_i_wait_mult": 0.75,
        "income_breakpoints": income_breakpoints,
    },
)

The first named dataset we include, tour, is by default the root node of this data tree. We then can define an arbitrary number of other named data nodes. Here, we add person, hh, odt_skims, and dot_skims. Note that these last two are actually two different names for the same underlying dataset, and for each name we will next define a unique set of relationships.

All data nodes in this tree are stored as Dataset objects. We can give a pandas DataFrame in this constructor instead, but it will be automatically converted into a one-dimensional Dataset. The conversion is no-copy if possible (and it is usually possible), so no additional memory is consumed in the conversion.

The relationships argument defines the links of the data tree. Each relationship maps a particular variable in a named upstream dataset to a particular dimension of a named downstream dataset. For example, "person.household_id @ hh.HHID" tells the tree that the household_id variable in the person dataset contains labels (@) that map to the HHID dimension of the hh dataset.

In addition to mapping by label, we can also map by position, by using the -> operator in the relationship string instead of @. In the example above, we map the tour destination TAZs in this manner, as the dest_taz_idx variable in the tours dataset contains positional references instead of labels.
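
The two styles can be sanity-checked with plain pandas before building the tree; these asserts are just an illustrative check, not part of sharrow’s API:

# label-based (@): every person.household_id value should be a valid HHID label
assert persons["household_id"].isin(households.index).all()
# position-based (->): every dest_taz_idx should be a valid position along dtaz
assert tours["dest_taz_idx"].between(0, skims.sizes["dtaz"] - 1).all()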

A special case for the relationship mapping is available when the source variable in the upstream dataset is explicitly categorical. In this case, sharrow checks that the categories exactly match the labels in the referenced downstream dataset dimension, and that there are no missing categorical values. If they do match and there are no missing values, the code points of the categories are used as positional mapping references, which is both memory and runtime efficient. If they don’t match, an error is raised, as it is presumed that the user has made a mistake… in theory sharrow could unravel the category values and do the mapping by label, but this would be a cumbersome operation, contrary to the efficiency goals of the library.
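
For example, here is a hedged sketch of making the time period columns explicitly categorical so that sharrow could use the category code points directly; the category list must exactly match the skims’ time_period labels, and this tours_cat variable is not used elsewhere in the notebook:

# Store the time periods as categoricals whose categories exactly match the
# time_period dimension labels of the skims, in the same order.
tours_cat = tours.assign(
    out_time_period=pd.Categorical(
        tours["out_time_period"], categories=skims["time_period"].values
    ),
    in_time_period=pd.Categorical(
        tours["in_time_period"], categories=skims["time_period"].values
    ),
)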

Lastly, our tree definition includes a few named constants, which are just fixed values defined separately. These are shown in two groups, extra_vars and aux_vars. The values in extra_vars get hard-coded into the compiled results, effectively the same as if their values were expanded and written into expressions in the spec directly. This is generally most efficient if the values will never change. On the other hand, aux_vars will be passed by reference into the compiled results. These values need to be numba-safe objects, so for instance a regular Python dictionary can’t be used, but a numba typed Dict is acceptable. So long as the data type and dimensionality of the values in aux_vars remain constant, the actual values can be changed later (i.e. after compilation).

Once we have defined our data tree, we can use it along with the spec, to compute the utility for various alternatives in the choice model. Sharrow allows us to compile this utility function into a Flow, which can be reused for massive speed gains on later utility evaluations.

flow = tree.setup_flow(spec.Expression)

To use a Flow for preparing the array of data that backs the utility function, we can call the load() method. The first time we call load(), it takes a (relatively) long time to evaluate, as the expressions are compiled and that compiled code is cached to disk.

%time flow.load()
CPU times: user 2.69 s, sys: 27.2 ms, total: 2.72 s
Wall time: 2.7 s
array([[ 9.4      , 16.9572   ,  4.5      ,  0.       ,  1.       ],
       [ 9.32     , 14.3628   ,  4.5      ,  1.       ,  1.       ],
       [ 7.62     , 11.0129   ,  4.5      ,  1.       ,  1.       ],
       ...,
       [ 8.52     , 11.6154995,  3.260325 ,  0.       ,  1.       ],
       [11.74     , 16.2798   ,  3.440325 ,  0.       ,  1.       ],
       [10.48     , 13.3974   ,  3.942825 ,  0.       ,  1.       ]],
      dtype=float32)

Subsequent calls to load() are much faster.

%time flow.load()
CPU times: user 42.2 ms, sys: 210 µs, total: 42.4 ms
Wall time: 12.6 ms
array([[ 9.4      , 16.9572   ,  4.5      ,  0.       ,  1.       ],
       [ 9.32     , 14.3628   ,  4.5      ,  1.       ,  1.       ],
       [ 7.62     , 11.0129   ,  4.5      ,  1.       ,  1.       ],
       ...,
       [ 8.52     , 11.6154995,  3.260325 ,  0.       ,  1.       ],
       [11.74     , 16.2798   ,  3.440325 ,  0.       ,  1.       ],
       [10.48     , 13.3974   ,  3.942825 ,  0.       ,  1.       ]],
      dtype=float32)

It’s faster not because the data has been cached, but because the compiled code has. (Setting the compile_watch argument to a truthy value will trigger a check of the cache files and emit a warning message if recompilation was triggered.) We can swap out the tour node in the tree for a different set of (similarly formatted) tours, and re-evaluate at that fast speed.
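
As a quick illustration of that cache check, here is a minimal sketch, assuming load() accepts the compile_watch argument as described above:

# Re-run the load with cache checking enabled; a warning is emitted only if
# the compiled code had to be rebuilt rather than loaded from the cache.
flow.load(compile_watch=True)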

tours_2 = random_tours(seed=43)
tours_2.head()
PERID dest_taz_idx out_time_period in_time_period
TOURIDX
0 2566803 6 EA AM
1 3596408 6 MD MD
2 1631117 8 EV EA
3 29658 13 EA EA
4 3138090 5 PM EA

Note that the flow requires not just a base dataset but a whole DataTree to operate, so to re-evaluate with a new tours table we need to make a new DataTree with replace_datasets. Fortunately, this operation is no-copy, so it doesn’t consume much memory. If all the datasets in the tree were linked by position (instead of by label), this would be almost instantaneous, but since our example tree has tours linked by label, it takes just a moment to rebuild the linkages.

tree_2 = tree.replace_datasets(tour=tours_2)
%time flow.load(tree_2)
CPU times: user 24.9 ms, sys: 3.35 ms, total: 28.3 ms
Wall time: 12.9 ms
array([[ 6.5299997,  9.7043   ,  4.3533   ,  0.       ,  1.       ],
       [ 4.91     ,  2.6404002,  1.16565  ,  1.       ,  1.       ],
       [ 4.8900003,  2.2564   ,  4.1078253,  0.       ,  1.       ],
       ...,
       [ 4.25     ,  7.6692   ,  2.50065  ,  0.       ,  1.       ],
       [ 7.3999996, 12.2662   ,  3.0513   ,  0.       ,  1.       ],
       [ 8.91     , 10.9822   ,  4.5      ,  0.       ,  1.       ]],
      dtype=float32)

The load function also has some other features, like nicely formatting the output into a DataFrame.

df = flow.load_dataframe()
df
Drive Time Transit IVT Transit Wait Time Income Constant
0 9.40 16.957199 4.500000 0.0 1.0
1 9.32 14.362800 4.500000 1.0 1.0
2 7.62 11.012900 4.500000 1.0 1.0
3 4.25 7.669200 2.500650 0.0 1.0
4 6.16 8.218600 3.387825 0.0 1.0
... ... ... ... ... ...
99995 4.86 4.928800 4.500000 0.0 1.0
99996 1.07 0.000000 0.000000 0.0 1.0
99997 8.52 11.615499 3.260325 0.0 1.0
99998 11.74 16.279800 3.440325 0.0 1.0
99999 10.48 13.397400 3.942825 0.0 1.0

100000 rows × 5 columns

Linear-in-Parameters Functions#

When the spec represents a linear-in-parameters utility function, the data we get out of the load() function represents one matrix in a dot-product, and the coefficients in the spec provide the other matrix. We might look to use the efficient linear algebra algorithms embedded in np.dot to compute the utility, like this:

x = flow.load()
b = spec.iloc[:, 1:].fillna(0).astype(np.float32).values
np.dot(x, b)
array([[-0.12595999, -0.4       , -0.83752644],
       [-0.124888  , -0.6       , -0.80276155],
       [-0.10210799, -0.6       , -0.7578729 ],
       ...,
       [-0.114168  , -0.4       , -0.74933606],
       [-0.157316  , -0.4       , -0.8142497 ],
       [-0.14043199, -0.4       , -0.782359  ]], dtype=float32)

But sharrow provides a substantially faster option, by embedding the dot product directly into the compiled code and never instantiating the full x array in memory at all.

%time flow.dot(b)
CPU times: user 1.6 s, sys: 251 ms, total: 1.85 s
Wall time: 1.55 s
array([[-0.12595999, -0.4       , -0.83752644],
       [-0.124888  , -0.6       , -0.80276155],
       [-0.10210799, -0.6       , -0.7578729 ],
       ...,
       [-0.114168  , -0.4       , -0.74933606],
       [-0.157316  , -0.4       , -0.8142497 ],
       [-0.14043199, -0.4       , -0.782359  ]], dtype=float32)
u = flow.dot(b)
u
array([[-0.12595999, -0.4       , -0.83752644],
       [-0.124888  , -0.6       , -0.80276155],
       [-0.10210799, -0.6       , -0.7578729 ],
       ...,
       [-0.114168  , -0.4       , -0.74933606],
       [-0.157316  , -0.4       , -0.8142497 ],
       [-0.14043199, -0.4       , -0.782359  ]], dtype=float32)

As before, the compiler runs only the first time we apply this function with this structure, and subsequent runs are faster, even with different source data.

%time flow.dot(b, source=tree_2)
CPU times: user 39.4 ms, sys: 31.2 ms, total: 70.6 ms
Wall time: 17.7 ms
array([[-0.087502  , -0.4       , -0.73837185],
       [-0.065794  , -0.6       , -0.6010011 ],
       [-0.065526  , -0.4       , -0.6352806 ],
       ...,
       [-0.05695   , -0.4       , -0.68627596],
       [-0.09915999, -0.4       , -0.7552545 ],
       [-0.119394  , -0.4       , -0.7574615 ]], dtype=float32)

As with the plain load method, the dot method also has formatted output versions. For example, dot_dataarray returns a DataArray.

flow.dot_dataarray(b, source=tree_2)
<xarray.DataArray (TOURIDX: 100000, ALT_COL: 3)> Size: 1MB
array([[-0.087502  , -0.4       , -0.73837185],
       [-0.065794  , -0.6       , -0.6010011 ],
       [-0.065526  , -0.4       , -0.6352806 ],
       ...,
       [-0.05695   , -0.4       , -0.68627596],
       [-0.09915999, -0.4       , -0.7552545 ],
       [-0.119394  , -0.4       , -0.7574615 ]], dtype=float32)
Coordinates:
  * TOURIDX  (TOURIDX) int64 800kB 0 1 2 3 4 5 ... 99995 99996 99997 99998 99999
Dimensions without coordinates: ALT_COL

This works even better if the coefficients are given as a DataArray too, so it can harvest dimension names and coordinates as appropriate.

B = xr.DataArray(
    spec.iloc[:, 1:].fillna(0).astype(np.float32), dims=("expressions", "modes")
)
flow.dot_dataarray(B, source=tree_2)
<xarray.DataArray (TOURIDX: 100000, modes: 3)> Size: 1MB
array([[-0.087502  , -0.4       , -0.73837185],
       [-0.065794  , -0.6       , -0.6010011 ],
       [-0.065526  , -0.4       , -0.6352806 ],
       ...,
       [-0.05695   , -0.4       , -0.68627596],
       [-0.09915999, -0.4       , -0.7552545 ],
       [-0.119394  , -0.4       , -0.7574615 ]], dtype=float32)
Coordinates:
  * TOURIDX  (TOURIDX) int64 800kB 0 1 2 3 4 5 ... 99995 99996 99997 99998 99999
Dimensions without coordinates: modes

Multinomial Logit Simulation#

The next level of flow evaluation is made by treating the dot-product as a linear-in-parameters multinomial logit (MNL) utility function, and making simulated choices based on that model. To do this, we’ll need to provide the random draws as a function input (which also lets us attach any randomization engine we prefer, e.g. a reproducible random generator). For this example, we’ll create one random (uniform) draw for each tour.

draws = np.random.default_rng(321).random(size=tree.shape[0])

Given those draws, we use the logit_draws method to build and apply an MNL simulator, which returns to us both the choices and the probability computed for each chosen alternative.

choices, choice_probs = flow.logit_draws(b, draws)
%time choices, choice_probs = flow.logit_draws(b, draws)
CPU times: user 55.4 ms, sys: 3.72 ms, total: 59.1 ms
Wall time: 20.3 ms

As this is the most complex flow processor, it takes the longest to compile, but after compilation it runs quite efficiently. We can see here the whole MNL simulation process for this data requires only a few milliseconds more time than just computing the utilities.

choices2, choice_probs2 = flow.logit_draws(b, draws, source=tree_2)
%time choices2, choice_probs2 = flow.logit_draws(b, draws, source=tree_2)
CPU times: user 69.5 ms, sys: 0 ns, total: 69.5 ms
Wall time: 17.8 ms

The resulting choices are reported as the index positions of the chosen alternatives, not their labels.

choices
array([1, 2, 1, ..., 0, 1, 0], dtype=int8)

But if we want the labels, it’s easy enough to convert these indexes into labels.

B.modes[choices]
<xarray.DataArray 'modes' (modes: 100000)> Size: 800kB
array(['WALK', 'TRANSIT', 'WALK', ..., 'DRIVE', 'WALK', 'DRIVE'], dtype=object)
Coordinates:
  * modes    (modes) object 800kB 'WALK' 'TRANSIT' 'WALK' ... 'WALK' 'DRIVE'

Nested Logit Simulation#

Sharrow can also apply nested logit models. To do so, you’ll also need to install a recent version of larch (e.g. conda install "larch>=5.7.1" -c conda-forge).

The nesting tree can be defined as usual in Larch, or you can use the construct_nesting_tree convenience function to read in a nesting tree definition according to the usual ActivitySim yaml notation, like this:

nesting_settings = """
NESTS:
  name: root
  coefficient: coef_nest_root
  alternatives:
      - name: MOTORIZED
        coefficient: coef_nest_motor
        alternatives:
            - DRIVE
            - TRANSIT
      - WALK
"""

import yaml

from sharrow.nested_logit import construct_nesting_tree

nesting_settings = yaml.safe_load(nesting_settings)["NESTS"]
nest_tree = construct_nesting_tree(
    alternatives=spec.columns[1:], nesting_settings=nesting_settings
)
nest_tree

Once the nesting tree is defined, it needs to be converted to operating arrays, using the as_arrays method (available in larch 5.7.1 and later). Since we are not estimating a nested logit model but just applying one, we can give the parameter values as a dictionary instead of a larch.Model to link against.

nesting = nest_tree.as_arrays(
    trim=True, parameter_dict={"coef_nest_motor": 0.5, "coef_nest_root": 1.0}
)

This dictionary of arrays can be passed in to the logit_draws function to compile a nested logit model instead of a simple MNL.

%time choices_nl, choice_probs_nl = flow.logit_draws(b, draws, nesting=nesting)
%time choices2_nl, choice2_probs_nl = flow.logit_draws(b, draws, source=tree_2, nesting=nesting)

For nested logit models, computing just the logsums is faster than generating probabilities (and making choices), so the logsums=1 argument allows you to short-circuit the computations if you only want the logsums.

flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=1)

Note that for consistency, the choices, probabilities of choices, and pick count arrays are still returned as the first three elements of the returned tuple, but they’re all zero-size empty arrays.
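
Based on that tuple structure, the logsums themselves can be pulled from the last element of the result; this unpacking is just illustrative, not a separate API:

# With logsums=1, the leading elements are zero-size placeholders and the
# last element of the returned tuple holds the logsums.
*_, logsums_only = flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=1)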

To get both the logsums and the choices, set logsums=2.

flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=2)

Batch Simulation#

Suppose we want to compute logsums not just for one destination, but for many destinations. We can construct a Dataset with two dimensions to use at the top of our DataTree, one for the tours and one for the candidate destinations.

tour_by_dest = tree.subspaces["tour"]
tour_by_dest = tour_by_dest.assign_coords(
    {"CAND_DEST": xr.DataArray(np.arange(25), dims="CAND_DEST")}
)
tour_by_dest

Then we can create a DataTree very similar to the one above, using this two-dimensional root Dataset, but we will point to our candidate destination zones from the new CAND_DEST dimension, and then create a flow from that.

wide_tree = sh.DataTree(
    tour=tour_by_dest,
    person=persons,
    hh=households,
    odt_skims=skims,
    dot_skims=skims,
    relationships=(
        "tour.PERID @ person.PERID",
        "person.household_id @ hh.HHID",
        "hh.TAZ @ odt_skims.otaz",
        "tour.CAND_DEST -> odt_skims.dtaz",
        "tour.out_time_period @ odt_skims.time_period",
        "tour.CAND_DEST -> dot_skims.otaz",
        "hh.TAZ @ dot_skims.dtaz",
        "tour.in_time_period @ dot_skims.time_period",
    ),
    extra_vars={
        "shortwait": 3,
        "one": 1,
    },
    aux_vars={
        "short_i_wait_mult": 0.75,
        "income_breakpoints": income_breakpoints,
    },
    dim_order=("TOURIDX", "CAND_DEST"),
)
wide_flow = wide_tree.setup_flow(spec.Expression)
wide_logsums = wide_flow.logit_draws(b, logsums=1, compile_watch="simple")[-1]
%time wide_logsums = wide_flow.logit_draws(b, logsums=1, compile_watch="simple")[-1]
wide_logsums