Using PlotCollection objects

This tutorial covers how to work with PlotCollection objects: what their main attributes and methods are, how they can be modified… It won’t focus on creating a PlotCollection.

Consequently, this should not be the first time you are hearing about PlotCollection. If it were, we recommend first going over either one of the following two pages:

PlotCollection attributes

viz: organized storage of plotting backend objects

The .viz attribute contains most of the elements that comprise the visualization itself: the chart, plots and artists.

We say “most of” because, while the chart and plot elements are created directly by PlotCollection methods like grid or wrap, artists are created by external functions that PlotCollection executes as many times as needed on the indicated plots, and some of these functions might not return an object from the plotting backend library to store.

from arviz_base import load_arviz_data
idata = load_arviz_data("rugby")
from arviz_plots import plot_dist, plot_forest, plot_trace_dist, style
style.use("arviz-clean")

ArviZ plotting functions aim to store as many artists as possible, which makes them available to users and allows further customization after the function has been called. Let’s see the contents of the PlotCollection returned by plot_dist:

pc = plot_dist(idata, var_names=["home", "intercept", "atts", "defs"])
[Figure: plot_dist output for home, intercept, atts and defs]
pc.viz
<xarray.DatasetView> Size: 8B
Dimensions:  ()
Data variables:
    chart    object 8B Figure(3450x1500)

The interactive HTML view of pc.viz (the text output above only shows its root-level chart variable) reveals that .viz is a DataTree with 4 groups, one per data variable. Each group contains the following elements:

  • plot: the backend objects that correspond to the plot elements.

  • row and col: integer indicators of the row and column each plot occupies within the chart.

  • kde, credible_interval, point_estimate, point_estimate_text and title: the artists corresponding, respectively, to the KDE line (blue line), the credible interval line (gray horizontal line), the point estimate dot (gray circle), the point estimate annotation (gray text over the point estimate) and the title (in bold black font over each plot).

Artist variables can have different shapes from one another, and the same artist variable can have different shapes in different groups (that is, for different data variables).

Moreover, there is a global chart variable which is always a scalar.

Important

The structure of the .viz attribute is backend agnostic, but its contents are backend dependent.

Here we have generated the plot with matplotlib so the objects stored are matplotlib objects like Figure, Axes, Line2D or Text.

However, if we generate the plot with bokeh (as below), the objects stored will be bokeh objects like Column, Figure, GlyphRenderer or Title.

pc = plot_dist(idata, backend="bokeh")
pc.viz
<xarray.DatasetView> Size: 8B
Dimensions:  ()
Data variables:
    chart    object 8B Column(id='2153', ...)

If we inspect the PlotCollection returned by plot_forest instead, we see that different artists are stored. All variables are drawn in the same plot and are distinguished by their y coordinate. As a result, the plot, row and col variables are now global: they are shared by all variables. We also have different shapes for different artists within the same variable, as well as different shapes for the same artist across variables.

pc = plot_forest(idata, var_names=["home", "atts", "defs"])
[Figure: plot_forest output for home, atts and defs]
pc.viz
<xarray.DatasetView> Size: 104B
Dimensions:  (column: 2)
Coordinates:
  * column   (column) <U6 48B 'labels' 'forest'
Data variables:
    chart    object 8B Figure(3450x1500)
    plot     (column) object 16B Axes(0.0387322,0.0657229;0.231268x0.925943) ...
    row      (column) int64 16B 0 0
    col      (column) int64 16B 0 1
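Because plot, row and col are global here, a specific column of the forest plot is selected from the root of .viz by its column coordinate, with no variable group involved. The sketch below reproduces only that root-level layout with a hand-built xarray.Dataset and draws a reference line on the forest column. The Dataset is an assumed stand-in for pc.viz, not the object plot_forest returns. Going by the output above, the equivalent selection on a real PlotCollection should be pc.viz["plot"].sel(column="forest").item().

```python
import matplotlib

matplotlib.use("Agg")  # headless backend: nothing is displayed
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

fig, axes = plt.subplots(1, 2)
plots = np.empty(2, dtype=object)
plots[0], plots[1] = axes  # one Axes per column

# hypothetical stand-in for the root level of pc.viz after plot_forest
viz_root = xr.Dataset(
    {
        "plot": ("column", plots),
        "row": ("column", [0, 0]),
        "col": ("column", [0, 1]),
    },
    coords={"column": ["labels", "forest"]},
)

# global variables are indexed by column, not by data variable
forest_ax = viz_root["plot"].sel(column="forest").item()
ref_line = forest_ax.axvline(0, color="gray", linestyle="--")
print(int(viz_root["col"].sel(column="forest")))
```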

aes: mapping of aesthetic keys to values and storage all at once

The other main attribute of PlotCollection is .aes. It is also a DataTree with a similar structure, with one group per data variable. Instead of storing plotted objects, though, it stores aesthetic mappings as key-value pairs. This lets us check which visual properties vary with the coordinate values each element represents. We can also access them for further plotting or, in more advanced cases, manually modify some of them before calling (more) plotting functions.

pc = plot_trace_dist(idata, var_names=["home", "intercept", "atts", "defs"])
[Figure: plot_trace_dist output for home, intercept, atts and defs]

In case it wasn’t completely clear from the plot (especially without legends, which we haven’t covered yet), inspecting the aes attribute shows that the linestyle depends on the coordinate value of the chain dimension, while the color depends on both the data variable and the team dimension.

pc.aes
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*

There is also an extra aesthetic called overlay. Its value is ignored, but its presence ensures that we loop over the right dimensions and draw the expected lines. This is helpful for plotting multiple subsets with the same visual properties, which is the default behaviour in plot_ppc. It also makes sure the plot behaves as expected even if we ignore some of the default aesthetic mappings, as we do in this example.

Customizing your PlotCollection

Modify specific visual elements

To apply keyword arguments globally, you can pass them directly when using map, or through plot_kwargs when using batteries-included functions. In some cases, however, we might want more control. One such example is highlighting the variables that correspond to the Scotland national team:

pc = plot_dist(idata, var_names=["home", "intercept", "atts", "defs"])
atts_scotland_kde = pc.viz["atts"]["kde"].sel(team="Scotland").item()
# atts_scotland_kde is now the Line2D object that
# corresponds to the kde line of the coordinate Scotland of variable atts
atts_scotland_kde.set(linewidth=3, color="lime")
pc.viz["defs"]["kde"].sel(team="Scotland").item().set(linewidth=3, color="lime");
[Figure: plot_dist output with the Scotland KDEs of atts and defs highlighted in lime]

Similarly, we can modify plot properties; for example, we can add a grid to the intercept plot:

pc = plot_dist(idata, var_names=["home", "intercept", "atts", "defs"])
pc.viz["intercept"]["plot"].item().grid(True)
[Figure: plot_dist output with a grid on the intercept plot]

We can inspect and modify any of the stored elements:

from bokeh.plotting import output_notebook
output_notebook()
pc = plot_dist(
    idata,
    var_names=["home", "atts", "defs"],
    backend="bokeh",
    # make plot smaller
    pc_kwargs={"plot_grid_kws": {"figsize": (1300, 600), "figsize_units": "dots"}},
)
pe_glyph = pc.viz["atts"]["point_estimate"].sel(team="Italy").item().glyph
pe_glyph.fill_color = "red"
pe_glyph.size = 20
pc.show()

We can also use the row and col indexes to select elements based on their position in the plot grid instead of their coordinates.

Note

Selection with row and col is a bit more convoluted than it might need to be, but this also serves to illustrate an important issue: some operations on DataTree/Dataset/DataArray objects trigger copies, which play well with neither bokeh nor matplotlib.

Here, for example, using .where(condition, drop=True), which would make things more direct, triggers a copy and, because of that, an error on the plotting backend side.

pc = plot_dist(
    idata,
    var_names=["home", "atts", "defs"],
    backend="bokeh",
    # make plot smaller
    pc_kwargs={"plot_grid_kws": {"figsize": (1300, 600), "figsize_units": "dots"}},
)

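defs_viz = pc.viz["defs"]
# mask every plot except the one at row 2, column 1 (not using drop=True, see the note above),
# then idxmax returns the team coordinate of the only entry left unmasked
team_at_row2_col1 = defs_viz["plot"].where(
    (defs_viz["row"] == 2) & (defs_viz["col"] == 1)
).idxmax("team").item()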
kde_glyph = defs_viz["kde"].sel(team=team_at_row2_col1).item().glyph
kde_glyph.line_color = "lime"
kde_glyph.line_width = 4
pc.show()
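The row/col selection above can be reproduced on a hand-built stand-in for pc.viz["defs"]. This is an assumption so the example runs on its own, and it uses matplotlib instead of bokeh. Following the note, the mask is applied with .where without drop=True. Entries outside the requested grid cell become NaN, and idxmax then returns the team coordinate of the single remaining plot.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend: nothing is displayed
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

teams = ["Scotland", "France", "Italy", "Wales"]
fig, axes = plt.subplots(2, 2)

plots = np.empty(len(teams), dtype=object)
kdes = np.empty(len(teams), dtype=object)
for i, ax in enumerate(axes.flat):
    plots[i] = ax
    kdes[i] = ax.plot([0, 1], [0, 1], color="C0")[0]

# hypothetical stand-in for pc.viz["defs"]: a 2x2 grid filled row by row
defs_viz = xr.Dataset(
    {
        "plot": ("team", plots),
        "kde": ("team", kdes),
        "row": ("team", [0, 0, 1, 1]),
        "col": ("team", [0, 1, 0, 1]),
    },
    coords={"team": teams},
)

# mask every plot except the one at row 1, column 0 (no drop=True),
# then idxmax returns the team label of the only entry that is not NaN
team_at_row1_col0 = (
    defs_viz["plot"]
    .where((defs_viz["row"] == 1) & (defs_viz["col"] == 0))
    .idxmax("team")
    .item()
)
kde_line = defs_viz["kde"].sel(team=team_at_row1_col0).item()
kde_line.set(color="lime", linewidth=4)
print(team_at_row1_col0)
```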

Extend the PlotCollection

Instead of modifying existing visual elements, we might want to add more elements to the plots. To add something to a specific plot, the procedure is essentially the same as above. The only difference is that we call a plotting function instead of modifying properties of the existing elements.

For example, let’s add a vertical reference line to the defs plot of the France national team:

pc = plot_dist(idata, var_names=["home", "atts", "defs"])
pc.viz["defs"]["plot"].sel(team="France").item().axvline(0, color="red");
[Figure: plot_dist output with a red vertical line at 0 on the France defs plot]

If we instead want to apply it to all plots, we can use map:

# to be able to use map, callables must accept 3 positional arguments:
# a DataArray, the plotting target and the backend
def axvline(da, target, backend, **kwargs):
    return target.axvline(0, **kwargs)

pc = plot_dist(idata, var_names=["home", "atts", "defs"])
pc.map(axvline, color="red")
[Figure: plot_dist output with a red vertical line at 0 on every plot]
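A map callable is an ordinary function, so it can be exercised by hand before passing it to map by supplying the three positional arguments directly. The sketch below does that for a hypothetical variant, axvline_at_mean, which draws the line at the mean of the DataArray it receives instead of at 0. The DataArray, the Axes and the "matplotlib" string passed as backend are stand-ins for what map would provide. Whether map passes the backend as a string is an assumption here; the callable ignores it anyway.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend: nothing is displayed
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr


def axvline_at_mean(da, target, backend, **kwargs):
    """Hypothetical map callable: vertical line at the mean of the data it receives."""
    # backend is unused because the call below is matplotlib-specific
    return target.axvline(float(da.mean()), **kwargs)


# stand-ins for what map would pass: a DataArray subset, a plot and the backend
fig, ax = plt.subplots()
da = xr.DataArray(np.array([1.0, 2.0, 3.0]), dims="draw")
line = axvline_at_mean(da, ax, "matplotlib", color="red")
print(line.get_xdata())
```

Once it behaves as expected, it can be passed to pc.map in the same way as axvline above. Where the lines end up then depends on which data map passes to each call.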

See also

The map method is one of the main building blocks provided by PlotCollection. The Create your own chart with PlotCollection page covers the use of map more extensively.

Legends

PlotCollection also provides a method to automatically generate legends for the plots.

Warning

The API of the add_legend method is still quite experimental.

For properties that are shared by all variables, generating the legend is relatively straightforward: there is a single legend, and the dimension name is a sensible choice of title.

pc = plot_trace_dist(idata, var_names=["home", "intercept", "atts", "defs"])
pc.add_legend("chain");
[Figure: plot_trace_dist output with a chain legend]

However, properties can also depend on both the data variable and one (or even multiple) dimensions. In such cases, like in the example below, the legend depends on the data variable of interest, and we should probably generate two legends. Or none at all: here the color is mostly used to help us visually group the multiple KDEs that correspond to the same coordinate values, not so much as a way to encode information.

pc = plot_trace_dist(
    idata,
    var_names=["home", "intercept", "atts", "defs"],
)
pc.add_legend("team", var_name="atts", title="team (atts)", loc="outside lower right", fontsize=10, ncols=3)
pc.add_legend("team", var_name="defs", title="team (defs)", loc="outside lower left", fontsize=10, ncols=3);
[Figure: plot_trace_dist output with team legends for atts and defs]