Introduction to batteries-included plots in arviz_plots

In this tutorial we’ll use plot_dist to show how to interact with the new aesthetics mapping and facetting powered by PlotCollection, effectively providing a first introduction to the use of plot_... plotting functions in arviz_plots.

plot_dist is a “batteries-included” function that plots 1D marginal distributions in the style of John K. Kruschke’s Doing Bayesian Data Analysis book. It used to be called plot_posterior.

import numpy as np
import arviz_plots as azp
from arviz_base import load_arviz_data
azp.style.use("arviz-clean")
schools = load_arviz_data("centered_eight")

Default behaviour

plot_... functions have only one required argument: the data to be plotted, as a DataTree (the InferenceData representation). These functions have a set of defaults regarding facetting, artist layout and labeling that should generate sensible figures regardless of the data.

In plot_dist’s case, the default is to generate a grid of subplots with one subplot per variable and per combination of coordinate values (after reducing the sampled dimensions). The schools dataset has 3 variables: mu and tau are scalar parameters of the model, while theta has an extra dimension, school, with 8 coordinate values. Thus, a grid with \(1+1+8=10\) subplots is generated:

azp.plot_dist(schools);
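
The counting rule can be made concrete with a small pure-Python sketch. The helper count_subplots and the dictionary of dimension sizes are our own illustration, with sizes taken from the description above. They are not part of arviz_plots:

```python
from math import prod

# Non-sample dimensions left on each centered_eight variable once the
# sample dimensions (chain, draw) have been reduced.
non_sample_dims = {
    "mu": {},
    "tau": {},
    "theta": {"school": 8},
}


def count_subplots(dims_per_var):
    """Hypothetical helper: one subplot per variable and coordinate combination."""
    # prod() of an empty iterable is 1, so scalar variables get one subplot each
    return sum(prod(dims.values()) for dims in dims_per_var.values())


n_subplots = count_subplots(non_sample_dims)
print(n_subplots)  # 10
```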

As you can see, each subplot combines 3 quantities derived from the data: the probability density (here represented with a KDE), the credible interval (here represented with an equal-tailed interval) and the point estimate (here the mean).
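
The point estimate and the credible interval can be approximated directly from the draws. The numpy sketch below uses synthetic draws and an illustrative probability of 0.94. It shows what an equal-tailed interval is, not how arviz_stats computes it internally:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for the draws of one scalar variable (chain x draw);
# these values are not the centered_eight posterior.
draws = rng.normal(loc=4.0, scale=3.0, size=(4, 500))

ci_prob = 0.94  # illustrative credible interval probability
flat = draws.ravel()  # reduce both sample dimensions together

point_estimate = flat.mean()
# Equal-tailed interval: leave (1 - ci_prob) / 2 of the probability in each tail
tail = (1 - ci_prob) / 2
eti_low, eti_high = np.quantile(flat, [tail, 1 - tail])

print(f"mean={point_estimate:.2f}, {ci_prob:.0%} ETI=[{eti_low:.2f}, {eti_high:.2f}]")
```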

Each plot_... function plots different quantities, but the pattern stays the same. There are a handful of top-level keyword arguments, related to data selection and general properties of the plot. There are also a handful of arguments that take dictionaries and provide finer control over the function’s behaviour.

In plot_dist’s case, the top level arguments are:

  • Data selection related: var_names, filter_vars, coords, group and sample_dims

  • General properties: kind to choose how to represent the probability density, point_estimate to choose which point estimate to use, ci_kind and ci_prob to control the credible interval, plot_collection in case you want to provide an existing PlotCollection instance, backend to choose the plotting backend and labeller to choose how to label each subplot.
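
As far as we know, var_names and filter_vars follow the conventions documented for ArviZ. Names are matched exactly by default, as substrings with filter_vars="like", or as regular expressions with filter_vars="regex", and a leading ~ excludes variables. The hypothetical select_vars helper below sketches those rules in plain Python. Check the plot_dist docstring for the authoritative behaviour:

```python
import re


def select_vars(all_vars, var_names=None, filter_vars=None):
    """Hypothetical sketch of var_names/filter_vars selection (not arviz code)."""
    if var_names is None:
        return list(all_vars)
    excluded = [name.startswith("~") for name in var_names]
    if any(excluded) and not all(excluded):
        raise ValueError("this sketch does not mix included and excluded names")
    # drop the leading "~" so only the pattern itself remains
    patterns = [name[1:] if name.startswith("~") else name for name in var_names]

    def matches(var, pattern):
        if filter_vars is None:
            return var == pattern  # exact name
        if filter_vars == "like":
            return pattern in var  # substring
        if filter_vars == "regex":
            return re.search(pattern, var) is not None  # regular expression
        raise ValueError(f"unknown filter_vars: {filter_vars!r}")

    hits = [var for var in all_vars if any(matches(var, p) for p in patterns)]
    if all(excluded):
        return [var for var in all_vars if var not in hits]
    return hits


print(select_vars(["mu", "tau", "theta"], ["~mu"]))                      # ['tau', 'theta']
print(select_vars(["mu", "theta", "theta_t"], ["theta"], "like"))         # ['theta', 'theta_t']
print(select_vars(["mu", "tau", "theta"], ["^t"], filter_vars="regex"))  # ['tau', 'theta']
```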

The dictionary arguments are aes_map, plot_kwargs, stats_kwargs and pc_kwargs. They are introduced in their own sections below.

plot_kwargs

plot_kwargs is a dictionary that dispatches keyword arguments through to the backend plotting functions. Its keys should be graphical elements and its values should be dictionaries that are passed as-is to the plotting functions. The docstring of each plot_... function lists the valid top-level keys and the function each dictionary value is dispatched to.

As you can see in its docstring, plot_dist has 5 valid keys: kde, credible_interval, point_estimate, point_estimate_text and title.

Modify properties of visual elements in all plots

We can use it to change the color of the KDE line:

azp.plot_dist(schools, plot_kwargs={"kde": {"color": "orange"}});

Or to change the linestyle of the credible interval together with the font style of the point estimate annotation and the font family of the title, this time restricting the plot to mu and tau:

azp.plot_dist(
    schools,
    var_names=["mu", "tau"],
    plot_kwargs={
        "credible_interval": {"linestyle": "--"},
        "point_estimate_text": {"fontstyle": "italic"},
        "title": {"fontfamily": "Charcoal"}
    }
);

Remove visual elements from the plot

plot_kwargs can also be used to remove visual elements from the plot. For example, to keep only the marginal distribution representation:

azp.plot_dist(
    schools, 
    plot_kwargs={
        "credible_interval": False,
        "point_estimate": False,
        "point_estimate_text": False,
    }
);
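
One way to picture how plot_kwargs is handled is that each key’s dictionary is merged over that element’s defaults, and False skips the element entirely. The sketch below is a hypothetical model of that behaviour. The DEFAULTS values are invented for illustration and are not the actual arviz_plots defaults:

```python
# Invented defaults standing in for the real arviz_plots ones, which depend
# on the backend and style in use.
DEFAULTS = {
    "kde": {"color": "C0"},
    "credible_interval": {"color": "C0", "linestyle": "-"},
    "point_estimate": {"color": "C0"},
    "point_estimate_text": {"fontstyle": "normal"},
    "title": {},
}


def resolve_plot_kwargs(plot_kwargs=None):
    """Hypothetical sketch: merge user dicts over defaults, drop False elements."""
    plot_kwargs = plot_kwargs or {}
    resolved = {}
    for element, defaults in DEFAULTS.items():
        user = plot_kwargs.get(element, {})
        if user is False:
            continue  # the element is not drawn at all
        resolved[element] = {**defaults, **user}  # user values win over defaults
    return resolved


kwargs = resolve_plot_kwargs({
    "kde": {"color": "orange"},
    "credible_interval": False,
    "point_estimate": False,
    "point_estimate_text": False,
})
print(kwargs)  # {'kde': {'color': 'orange'}, 'title': {}}
```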

stats_kwargs

stats_kwargs is a dictionary that dispatches keyword arguments through to the statistical computation functions.

In plot_dist we can use it, for example, to control the KDE computation and change the bandwidth selection algorithm:

azp.plot_dist(schools, stats_kwargs={"density": {"bw": "scott"}});
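
For reference, a Gaussian KDE with a Scott-style bandwidth can be written in a few lines of numpy. We use the normal-reference form \(h = 1.06\,\hat\sigma\,n^{-1/5}\). The constant varies between libraries (scipy’s gaussian_kde, for example, omits the 1.06), and this sketch is not the arviz_stats implementation:

```python
import numpy as np


def scott_bandwidth(x):
    """Normal-reference bandwidth: 1.06 * std * n**(-1/5)."""
    return 1.06 * np.std(x) * len(x) ** (-1 / 5)


def gaussian_kde(x, grid, bw):
    """Evaluate a Gaussian kernel density estimate of the sample x on grid."""
    z = (grid[:, None] - x[None, :]) / bw  # shape: (grid points, samples)
    return np.exp(-0.5 * z**2).sum(axis=1) / (len(x) * bw * np.sqrt(2 * np.pi))


rng = np.random.default_rng(1)
sample = rng.normal(size=2000)  # synthetic draws, not a model posterior
grid = np.linspace(-5, 5, 501)
bw = scott_bandwidth(sample)
density = gaussian_kde(sample, grid, bw)
```

A wider bandwidth gives a smoother curve. Swapping the bandwidth rule changes only bw, not the kernel sum.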

pc_kwargs and aes_map

pc_kwargs are passed to arviz_plots.PlotCollection.wrap to initialize the PlotCollection object, which takes care of facetting and aesthetics mapping and generates and manages the chart. With it we can control anything from the figure size or the sharing of axis limits to the complete layout and aesthetics of the generated plot.

Each chart has a set of mappings between dataset properties and graphical properties. For example, we might encode the school information (a dataset property) with the color (a graphical property). These mappings are shared between all subplots and between all graphical elements. aes_map regulates which mappings apply to which graphical elements; by default, mappings only apply to the density representation.

Adding aesthetic mappings to a visualization

We can start by defining an aesthetics mapping:

azp.plot_dist(
    schools,
    pc_kwargs={
        # encode the school information in the color property
        "aes": {"color": ["school"]},
    }
);

Note that mu and tau share a common color, which differs from all of the theta lines. Because they don’t have the school dimension, PlotCollection takes the first element of the aesthetic values (in this case “C0”, the first color of the matplotlib color cycle) as a neutral element, and generates the mapping from the remaining elements. The neutral element is therefore reserved for cases where the mapping can’t be applied, and only for those.
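
The neutral-element rule can be sketched in a few lines of plain Python. map_aesthetic and color_for are hypothetical helpers that mimic the behaviour described above, not PlotCollection code. The school names are the coordinate values of the centered_eight dataset:

```python
def map_aesthetic(values, coords):
    """Reserve values[0] as the neutral element and map coords to the rest."""
    neutral, *rest = values
    if len(rest) < len(coords):
        raise ValueError("not enough aesthetic values for this sketch")
    return neutral, dict(zip(coords, rest))


colors = [f"C{i}" for i in range(10)]  # matplotlib color cycle names C0..C9
schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
           "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]
neutral, school_colors = map_aesthetic(colors, schools)


def color_for(var_dims, school=None):
    # variables without the school dimension (mu, tau) fall back to the neutral element
    return school_colors[school] if "school" in var_dims else neutral


print(color_for(()))                     # C0
print(color_for(("school",), "Choate"))  # C1
```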

Removing aesthetic mappings from a visualization

Similarly, we can use pc_kwargs to remove aesthetic mappings from plots that define them by default. For example, plot_trace_dist by default maps the linestyle to the chain dimension, and the color to the variable and all non-sample dimensions together. Setting an aesthetic to False removes its mapping:

azp.plot_trace_dist(
    schools,
    pc_kwargs={"aes": {"linestyle": False}},
);

Choosing the artists where aesthetic mappings are applied

We can configure which artists take the defined aesthetic mappings into account with aes_map:

azp.plot_dist(
    schools,
    pc_kwargs={"aes": {"color": ["school"]}},
    # apply the color-school mapping to all graphical elements except the title and the point estimate text
    aes_map={
        "kde": ["color"],
        "point_estimate": ["color"],
        "credible_interval": ["color"]
    }
);
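
In other words, aes_map acts as a filter between the aesthetic values resolved for a subplot and the keyword arguments each artist receives. The hypothetical artist_kwargs helper below sketches that filtering for the previous example. The value "C1" stands for whatever color the school of that subplot is mapped to:

```python
def artist_kwargs(aes_values, aes_map, artist):
    """Keep only the aesthetics that aes_map assigns to the given artist."""
    allowed = aes_map.get(artist, [])
    return {aes: value for aes, value in aes_values.items() if aes in allowed}


# Aesthetic values resolved for one theta subplot (illustrative)
aes_values = {"color": "C1"}
aes_map = {
    "kde": ["color"],
    "point_estimate": ["color"],
    "credible_interval": ["color"],
}

for artist in ["kde", "point_estimate", "credible_interval", "point_estimate_text", "title"]:
    print(artist, artist_kwargs(aes_values, aes_map, artist))
# point_estimate_text and title receive {} and keep their default styling
```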

We can have as many aesthetic mappings as desired, and map all of them, none of them or a subset of them to the different graphical elements:

azp.plot_dist(
    schools, 
    pc_kwargs={
        "aes": {"color": ["school"], "linestyle": ["chain"]},
        "plot_grid_kws": {"figsize": (12, 7)}
    },
    aes_map={"kde": ["color", "linestyle"], "point_estimate": ["color"]},
);

Note that there is now an aesthetic (linestyle) mapped to the chain dimension. PlotCollection therefore loops over the chain dimension to enforce the aesthetic mapping, generating 4 KDE lines (one per chain) in each subplot.
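
A rough way to count the lines per subplot is as follows. Dimensions used for facetting are fixed within a subplot, while the remaining dimensions that carry an aesthetic are looped over, with one artist per coordinate combination. artists_per_subplot is a hypothetical helper that expresses this reading of the behaviour:

```python
from math import prod


def artists_per_subplot(var_dims, facet_dims, aes_dims):
    """Hypothetical helper: number of artists of one kind in a single subplot.

    var_dims maps each dimension still present on the variable (including
    chain, which is kept because linestyle is mapped to it) to its size.
    """
    loop_dims = [d for d in aes_dims if d in var_dims and d not in facet_dims]
    return prod(var_dims[d] for d in loop_dims)


facet_dims = {"school"}         # default: one subplot per school
aes_dims = {"school", "chain"}  # color -> school, linestyle -> chain

print(artists_per_subplot({"chain": 4, "school": 8}, facet_dims, aes_dims))  # 4 (theta)
print(artists_per_subplot({"chain": 4}, facet_dims, aes_dims))               # 4 (mu, tau)
```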

Legends

Legends are not added automatically, but they can be generated with the PlotCollection object returned by all plot_... functions:

pc = azp.plot_dist(
    schools, 
    pc_kwargs={
        "aes": {"color": ["school"], "linestyle": ["chain"]},
        "plot_grid_kws": {"figsize": (12, 7)}
    },
    aes_map={"kde": ["color", "linestyle"], "point_estimate": ["color"]},
)
pc.add_legend("school", loc="outside right upper")
pc.add_legend("chain", loc="outside right lower");

Advanced examples

So far we have used the term aesthetics for the properties in which we encode information, and used color and linestyle as such properties. However, PlotCollection only manages the mapping between a dataset property (in the form of a dimension) and a graphical property, and both the keys and the values of that mapping can be anything. The only limitation comes later, when the mapped properties are passed down to the plotting functions: these keys and values must be valid for the plotting function, otherwise you’ll get an error.

The functions used by plot_dist (and by other plot_... functions) aim to be fairly general; for example, y is a valid key for encoding information. But keep in mind that if you want to generate plots significantly different from the default layout of plot_dist, you’ll need to follow the steps in Create your own chart with PlotCollection and use PlotCollection manually.

azp.plot_dist(
    schools, 
    pc_kwargs={
        # stop creating one subplot per variable *and* coordinate value,
        # generate only one per variable, in this case 3 subplots
        "cols": ["__variable__"],
        # encode the school information in both color and y properties
        "aes": {"color": ["school"], "y": ["school"]},
        "y": np.linspace(0, 0.06, 9),
    },
    aes_map={
        "kde": ["color"],
        "point_estimate": ["color", "y"],
        "credible_interval": ["y"]
    },
    plot_kwargs={"point_estimate_text": {"bbox": {"boxstyle": "round", "fc": (1, 1, 1, 0.7)}}},
);
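
The 9 y values passed in pc_kwargs are one more than the 8 schools. Assuming the neutral-element rule described earlier also applies to y, the first value (0) is kept for mu and tau and the remaining 8 are assigned to the schools. This plain numpy sketch shows the assignment:

```python
import numpy as np

y_values = np.linspace(0, 0.06, 9)  # same values as in pc_kwargs above
schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
           "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]

# first value: neutral element for variables without the school dimension
neutral_y, *school_y = y_values
y_by_school = dict(zip(schools, school_y))

print(neutral_y)                  # 0.0
print(y_by_school["Mt. Hermon"])  # 0.06
```

Under the same assumption, the rugby example uses "y": [0, 0.1, 0.2] for the same reason: one neutral value plus one value for each of the 2 field coordinates.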

We can have a range of variables with different shapes and dimensions. As long as the facetting and mapping arguments don’t contradict each other, the underlying PlotCollection used by plot_dist can combine all the different variables into a single plot, mapping the provided aesthetics only where relevant:

from arviz_base.datasets import REMOTE_DATASETS, RemoteFileMetadata
# TODO: remove this monkeypatching once the arviz_example_data repo has been updated
REMOTE_DATASETS.update({
    "rugby_field": RemoteFileMetadata(
        name="rugby_field",
        filename="rugby_field.nc",
        url="https://figshare.com/ndownloader/files/44667112",
        checksum="53a99da7ac40d82cd01bb0b089263b9633ee016f975700e941b4c6ea289a1fb0",
        description="""Variant of the rugby model."""
    )
})
rugby = load_arviz_data("rugby_field")

Here, for example, we have 4 variables (counting the chain and draw dimensions): a two-dimensional one, two three-dimensional ones (each with a different third dimension) and a four-dimensional one:

rugby.posterior.ds[["atts_team", "atts", "intercept", "sd_att"]]
<xarray.Dataset> Size: 340kB
Dimensions:    (chain: 4, draw: 500, team: 6, field: 2)
Coordinates:
  * chain      (chain) int64 32B 0 1 2 3
  * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * field      (field) <U4 32B 'home' 'away'
  * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
Data variables:
    atts_team  (chain, draw, team) float64 96kB ...
    atts       (chain, draw, field, team) float64 192kB ...
    intercept  (chain, draw, field) float64 32kB ...
    sd_att     (chain, draw) float64 16kB ...
Attributes:
    created_at:                 2024-02-23T20:21:03.016373
    arviz_version:              0.17.0
    inference_library:          pymc
    inference_library_version:  5.10.4+7.g34d2a5d9
    sampling_time:              21.146891355514526
    tuning_steps:               1000
pc = azp.plot_dist(
    rugby,
    var_names=["atts_team", "atts", "intercept", "sd_att"],
    pc_kwargs={
        "cols": ["__variable__", "team"],
        "col_wrap": 6,
        "plot_grid_kws": {"figsize": (10, 6)},
        "aes": {
            "linestyle": ["field"],
            "color": ["team"],
            "marker": ["field"],
            "y": ["field"]
        },
        "y": [0, 0.1, 0.2]
    },
    aes_map={
        "kde": ["color", "linestyle"],
        "point_estimate": ["color", "marker", "y"],
        "point_estimate_text": ["color", "y"],
        "credible_interval": ["y"]
    },
)
pc.add_legend("team")
pc.add_legend("field", loc="outside right center");
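
With cols=["__variable__", "team"], each variable gets one subplot per team, or a single one if it has no team dimension. The field dimension is instead encoded through aesthetics inside each subplot. The plain-Python sketch below counts panels and rows under this reading of the facetting rules. The helper n_panels_for is hypothetical:

```python
from math import ceil, prod

# Non-sample dimensions of each rugby variable, from the dataset repr above
var_dims = {
    "atts_team": {"team": 6},
    "atts": {"field": 2, "team": 6},
    "intercept": {"field": 2},
    "sd_att": {},
}
facet_dims = {"team"}  # facetted besides __variable__, per "cols" in pc_kwargs
col_wrap = 6


def n_panels_for(dims):
    # only facetted dims add subplots; "field" is looped over within a subplot
    return prod(size for dim, size in dims.items() if dim in facet_dims)


n_panels = sum(n_panels_for(dims) for dims in var_dims.values())
n_rows = ceil(n_panels / col_wrap)
print(n_panels, n_rows)  # 14 3
```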

See also

  • Using PlotCollection objects covers handling of PlotCollection objects to further customize and inspect the plots generated with batteries-included functions.

  • Create your own chart with PlotCollection shows how to create and fill visualizations from scratch using PlotCollection, so that you can write your own specific plotting functions or domain-specific batteries-included ones.