This tutorial covers how to work with an existing PlotCollection: what its main attributes and methods are and how it can be modified.
It does not cover PlotCollection creation.
Consequently, this should not be the first time you are hearing about PlotCollection. If it were, we recommend first going over either one of the following two pages:
Introduction to batteries-included plots in arviz_plots which introduces the “batteries-included” functions. That is, functions that take data following the InferenceData schema and generate a specific type of plot, using an opinionated and pre-defined set of defaults. All these functions return a PlotCollection object.
viz: organized storage of plotting backend objects
The .viz attribute contains most of the elements that comprise the visualization itself: the chart, plots and artists.
We say “most of” because, while the chart and plot elements are created directly by methods of PlotCollection like grid or
wrap, artists are created by external functions executed through PlotCollection as many times as needed on the indicated plots,
and some of these functions might not return an object from the plotting backend library to store.
ArviZ plotting functions aim to store as many artists as possible, which makes them available to users and allows further customization after the function has been called.
Let’s look at the contents of the .viz attribute of the PlotCollection returned by plot_dist:
[Interactive HTML view of pc.viz; expanded array listings condensed. Each data variable group stores 6-element object arrays indexed by team, for example:]
credible_interval  (team)  object  Line2D(_child1) ... Line2D(_child1)
point_estimate  (team)  object  <matplotlib.collections.PathColl...
[The global chart variable is:]
array(<Figure size 3450x1500 with 15 Axes>, dtype=object)
As you can see by inspecting the HTML interactive view right above, the .viz attribute is a DataTree with 4 groups, one per data variable.
Each group has the following elements:
plot: the backend objects that correspond to the plot elements.
row and col: integer indicators of the row and column each plot occupies within the chart.
kde, credible_interval, point_estimate, point_estimate_text and title: the artists corresponding respectively to: the KDE line (blue line), credible interval line (gray horizontal line), the point estimate dot (gray circle), the point estimate annotation (gray text over the point estimate) and the title (in bolded black font over each plot).
These artist variables storing backend objects can have different shapes between them and the same artist variable can have different shapes between groups/data variables.
Moreover, there is a global chart variable which is always a scalar.
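To illustrate how objects stored in such arrays can be retrieved by coordinate, here is a minimal sketch that mimics a single artist variable with plain xarray. The FakeArtist class, the team names and the array contents are placeholders introduced for illustration; a real PlotCollection stores backend objects such as Line2D instead.

```python
import numpy as np
import xarray as xr


class FakeArtist:
    """Hypothetical stand-in for a backend object such as a matplotlib Line2D."""

    def __init__(self, label):
        self.label = label
        self.props = {}

    def set(self, **kwargs):
        # Record the requested visual properties, as an artist would apply them.
        self.props.update(kwargs)


teams = ["Scotland", "France", "Italy"]  # illustrative coordinate values

# Fill an object array element by element so numpy does not try to unpack the objects.
artists = np.empty(len(teams), dtype=object)
for i, team in enumerate(teams):
    artists[i] = FakeArtist(f"kde-{team}")

# One artist variable with a "team" dimension, similar in spirit to a kde variable in .viz.
kde = xr.DataArray(artists, dims=["team"], coords={"team": teams}, name="kde")

# .sel picks the 0-d element for one coordinate and .item() unwraps the stored object.
scotland_kde = kde.sel(team="Scotland").item()
scotland_kde.set(linewidth=3, color="lime")
```

Because the array only holds references, the object returned by .item() is the stored object itself, so changing its properties affects the element kept in the array.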
Important
The structure of the .viz attribute is backend agnostic, but its contents are backend dependent.
Here we have generated the plot with matplotlib so the objects stored are matplotlib objects like Figure, Axes, Line2D or Text.
However, if we generate the plot with bokeh (like below) the objects stored will be bokeh objects like Column, Figure, GlyphRenderer or Title.
If we instead inspect the PlotCollection returned by plot_forest, we’ll see that different artists are stored.
In that case, all variables are drawn in the same plot and are distinguished by their y coordinate, so the plot, row and col variables
are now global, as they are shared by all variables. Moreover, we now also have different shapes for different artists within the same variable,
as well as different shapes for the same artist among different variables.
[Interactive HTML view of the .viz attribute returned by plot_forest; expanded array listings condensed. Artist variables have shape (chain) with 4 elements in some groups and (chain, team) with 4 x 6 elements in others, for example:]
trunk  (chain)  object  Line2D(_child52) ... Line2D(_chi...
point_estimate  (chain)  object  <matplotlib.collections.PathColl...
trunk  (chain, team)  object  Line2D(_child56) ... Line2D(_chi...
point_estimate  (chain, team)  object  <matplotlib.collections.PathColl...
[Global variables, with dimension column: 2:]
column  (column)  <U6  'labels' 'forest'
chart  ()  object  Figure(3450x1500)
plot  (column)  object  Axes(0.0387322,0.0657229;0.23126...
row  (column)  int64  0 0
col  (column)  int64  0 1
aes: mapping of aesthetic keys to values and storage all at once
The other main attribute of PlotCollection is .aes. It is also a DataTree and it has a similar structure, with data variables being groups,
but instead of storing plotted objects it stores aesthetic mappings as key-value pairs. This allows us to check which properties
are specific to the different visual elements depending on the coordinate values they represent, to access them for further plotting or,
in more advanced cases, to manually modify some of them before calling (more) plotting functions.
In case it wasn’t completely clear from the plot (especially without legends, which we haven’t covered yet), by inspecting the aes attribute we can see
that the linestyle depends on the coordinate value of the chain dimension, and the color depends on both the data variable and the team dimension.
pc.aes
<xarray.DatasetView> Size: 0B
Dimensions: ()
Data variables:
*empty*
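As a rough mental model of what an aesthetic mapping stores, the sketch below builds key-value pairs per coordinate value in plain Python. The helper name build_aes_mapping, the example coordinate values and the rule of reusing values cyclically are illustrative assumptions, not the arviz_plots implementation, and the sketch ignores the dependence of the color on the data variable.

```python
from itertools import cycle


def build_aes_mapping(coord_values, aes_values):
    """Pair each coordinate value with an aesthetic value, reusing values cyclically."""
    return dict(zip(coord_values, cycle(aes_values)))


chains = [0, 1, 2, 3]
teams = ["Scotland", "France", "Italy"]

# linestyle depends on the chain, color depends on the team
linestyle_by_chain = build_aes_mapping(chains, ["-", "--", ":"])
color_by_team = build_aes_mapping(teams, ["C0", "C1", "C2"])

# Properties for the visual element that represents chain 3 of team France
props = {"linestyle": linestyle_by_chain[3], "color": color_by_team["France"]}
print(props)
```

Looking up the properties by coordinate value is what makes it possible to reuse the same mapping in later plotting calls, so that new elements match the ones already drawn.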
There is also an extra aesthetic called overlay, whose value is ignored but whose presence ensures we’ll loop over the right dimensions and draw the expected lines.
This is helpful to plot multiple subsets all with the same visual properties, which is the default behaviour in plot_ppc, or to ensure the plot
behaves as expected even if we ignore some of the default aesthetic mappings, like we do in this example.
In order to use global keyword arguments, you can pass them directly if using map or use plot_kwargs if using batteries-included functions.
In some cases however we might want more control. One such example could be highlighting the variables that correspond to the national team of Scotland:
pc = plot_dist(idata, var_names=["home", "intercept", "atts", "defs"])
atts_scotland_kde = pc.viz["atts"]["kde"].sel(team="Scotland").item()
# atts_scotland_kde is now the Line2D object that
# corresponds to the kde line of the coordinate Scotland of variable atts
atts_scotland_kde.set(linewidth=3, color="lime")
pc.viz["defs"]["kde"].sel(team="Scotland").item().set(linewidth=3, color="lime");
Similarly we can also modify plot properties, for example, add a grid to the intercept one:
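A minimal sketch of such a plot-level change with plain matplotlib, assuming the matplotlib backend so that the stored plot element is an Axes. The figure and data below are made up; with a PlotCollection, the Axes would come from the plot variable of the intercept group rather than from plt.subplots.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Stand-in for the Axes stored in the `plot` variable of the intercept group.
fig, ax = plt.subplots()
ax.plot([-1.0, 0.0, 1.0], [0.1, 0.8, 0.1])

# Plot-level properties are changed through the Axes, not through an artist.
ax.grid(True, linestyle=":", color="gray")
```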
# same idea with the bokeh backend: restyle the point estimate of Italy in atts
pc = plot_dist(
    idata,
    var_names=["home", "atts", "defs"],
    backend="bokeh",
    # make plot smaller
    pc_kwargs={"plot_grid_kws": {"figsize": (1300, 600), "figsize_units": "dots"}},
)
pe_glyph = pc.viz["atts"]["point_estimate"].sel(team="Italy").item().glyph
pe_glyph.fill_color = "red"
pe_glyph.size = 20
pc.show()
We can also use the row and col indexes to select elements based on their position in the plot grid, not using their coordinates.
Note
Selection with row and column is a bit more convoluted than it might need to be, but this also serves to illustrate an important issue.
Some operations on the DataTree/Dataset/DataArray objects will trigger copies, which don’t play well with either bokeh or matplotlib.
Here, for example, attempting to use .where(condition,drop=True), which would make things more direct, will trigger a copy and, because of that, an error
on the plotting backend side.
pc = plot_dist(
    idata,
    var_names=["home", "atts", "defs"],
    backend="bokeh",
    # make plot smaller
    pc_kwargs={"plot_grid_kws": {"figsize": (1300, 600), "figsize_units": "dots"}},
)
defs_viz = pc.viz["defs"]
team_at_row2_col1 = (
    defs_viz["plot"]
    .where((defs_viz["row"] == 2) & (defs_viz["col"] == 1))
    .idxmax("team")
    .item()
)
kde_glyph = defs_viz["kde"].sel(team=team_at_row2_col1).item().glyph
kde_glyph.line_color = "lime"
kde_glyph.line_width = 4
pc.show()
Rather than modifying existing visual elements, we might want to add more elements to the plots.
If we want to add something to a specific plot, the procedure is basically the same as above; the only difference
is that we call a plotting function instead of modifying properties of the existing elements.
For example, let’s add a vertical reference line to the defs plot of the France national team:
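A sketch of that pattern, assuming the matplotlib backend: the plot variable is mimicked with a hand-built object array of Axes indexed by team, and the team names and the reference value 0 are illustrative. With a real PlotCollection, the Axes would be selected from the plot variable of the defs group in the same way.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

teams = ["Scotland", "France", "Italy"]  # illustrative coordinate values
fig, axes = plt.subplots(1, len(teams))

# Hand-built stand-in for the `plot` variable of one group: one Axes per team.
plot_arr = np.empty(len(teams), dtype=object)
for i, ax in enumerate(axes):
    plot_arr[i] = ax
plots = xr.DataArray(plot_arr, dims=["team"], coords={"team": teams}, name="plot")

# Select the Axes for France and call a plotting method on it.
france_ax = plots.sel(team="France").item()
ref_line = france_ax.axvline(0, color="red", linestyle="--")
```

Only the selected Axes receives the new line; the plots for the other teams are left untouched.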
If we instead want to apply it to all plots, we can use map:
# to be able to use map, callables must accept 3 positional arguments
# a DataArray, the plotting target and the backend
def axvline(da, target, backend, **kwargs):
    return target.axvline(0, **kwargs)

pc = plot_dist(idata, var_names=["home", "atts", "defs"])
pc.map(axvline, color="red")
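To make the calling convention concrete, here is a toy loop in plain Python that invokes a callable as the comment above describes: a data subset, a plotting target and a backend name as positional arguments, followed by keyword arguments. The toy_map function and the RecordingTarget class are hypothetical stand-ins and not the PlotCollection.map implementation.

```python
class RecordingTarget:
    """Hypothetical plotting target that records the calls it receives."""

    def __init__(self):
        self.calls = []

    def axvline(self, x, **kwargs):
        self.calls.append(("axvline", x, kwargs))
        return self.calls[-1]


def axvline(da, target, backend, **kwargs):
    # Same signature as in the text: data, plotting target, backend, then kwargs.
    return target.axvline(0, **kwargs)


def toy_map(fun, data_by_plot, targets, backend, **kwargs):
    """Call `fun` once per plot, passing the matching data subset and target."""
    return {
        key: fun(data_by_plot[key], targets[key], backend, **kwargs)
        for key in targets
    }


targets = {"home": RecordingTarget(), "atts": RecordingTarget()}
data = {"home": [0.3, 0.4], "atts": [0.1, -0.2]}
results = toy_map(axvline, data, targets, "matplotlib", color="red")
```

Keyword arguments given once are forwarded to every call, which is why a single color="red" reaches each plot.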
See also
The map method is one of the main building blocks provided by PlotCollection. The Create your own chart with PlotCollection page covers the use of map more extensively.
PlotCollection also provides a method to automatically generate legends for the plots.
Warning
The API of the add_legend method is still quite experimental.
For properties that are shared by all variables, generating the legend is relatively straightforward: it is unique, and using the dimension name as its title is a sensible choice.
It is also possible, however, to have properties that depend on both the data variable and on one (or even multiple) dimensions. In such cases, like in the example below,
the legend becomes dependent on the data variable of interest, and we should probably generate two legends, or none at all. Note that here the color is mostly used to help us visually cluster the multiple KDEs that correspond to the same coordinate values, not so much as a way to encode information.