arviz_plots.plot_forest
- arviz_plots.plot_forest(dt, var_names=None, filter_vars=None, group='posterior', coords=None, sample_dims=None, combined=False, point_estimate=None, ci_kind=None, ci_probs=None, labels=None, shade_label=None, plot_collection=None, backend=None, labeller=None, aes_map=None, plot_kwargs=None, stats_kwargs=None, pc_kwargs=None)
Plot 1D marginal credible intervals in a single plot.
- Parameters:
- dt : datatree.DataTree or dict of {str : datatree.DataTree}
Input data. In case of dictionary input, the keys are taken to be model names. In such cases, a dimension “model” is generated and can be used to map to aesthetics.
plot_forest uses the dimension “column” (creating it if necessary) to generate the grid, then adds the intervals and point estimates to its “forest” coordinate and the labels to its “labels” coordinate. The data used to plot is then the subset column="forest".
- var_names : str or list of str, optional
One or more variables to be plotted. Prefix the variables by ~ when you want to exclude them from the plot.
- filter_vars : {None, “like”, “regex”}, default None
If None, interpret var_names as the real variable names. If “like”, interpret var_names as substrings of the real variable names. If “regex”, interpret var_names as regular expressions on the real variable names.
- group : str, default “posterior”
Group to be plotted.
- coords : dict, optional
- sample_dims : str or sequence of hashable, optional
Dimensions to reduce unless mapped to an aesthetic. Defaults to rcParams["data.sample_dims"].
- combined : bool, default False
Whether to combine all chains into a single interval. If False, an interval is plotted for each chain. Ignored when the “chain” dimension is not present.
- point_estimate : {“mean”, “median”, “mode”}, optional
Which point estimate to plot. Defaults to rcParams["plot.point_estimate"].
- ci_kind : {“eti”, “hdi”}, optional
Which credible interval to use. Defaults to rcParams["stats.ci_kind"].
- ci_probs : (float, float), optional
Indicates the probabilities that should be contained within the plotted credible intervals. It should be sorted, as its elements are the probabilities of the “trunk” and “twig” elements, in that order. Defaults to (0.5, rcParams["stats.ci_prob"]).
- labels : sequence of str, optional
Sequence with the dimensions to be labelled in the plot. By default, all dimensions except “chain” and “model” (if present). The order of labels is ignored; only which elements are present matters. It can include the special “__variable__” indicator, and does so by default.
- shade_label : str, default None
Element of labels that should be used to add horizontal shading strips to the plot. Note that labels and credible intervals are drawn in different plots. The shading is applied to both plots, and the spacing between them is set to 0 if possible, which is not always the case (one notable example being matplotlib’s constrained layout).
- plot_collection : PlotCollection, optional
- backend : {“matplotlib”, “bokeh”}, optional
- labeller : labeller, optional
- aes_map : mapping of {str : sequence of str or False}, optional
Mapping of artists to aesthetics that should use their mapping in plot_collection when plotted. Valid keys are the same as for plot_kwargs, except “ticklabels”, which doesn’t apply, and “twig” and “trunk”, which, similarly to stats_kwargs, take the same aesthetics through the “credible_interval” key.
By default, aesthetic mappings are generated for: y, alpha, overlay and color (if multiple models are present). All aesthetic mappings but alpha are applied to both the credible intervals and the point estimate; overlay is applied to labels; and both overlay and alpha are applied to the shade.
“overlay” is a dummy aesthetic to trigger looping over variables and/or dimensions using all aesthetics in every iteration. “alpha” gets two values (0, 0.3) in order to trigger the alternate shading effect.
- plot_kwargs : mapping of {str : mapping or False}, optional
Valid keys are:
trunk, twig -> passed to line_x
point_estimate -> passed to scatter_x
labels -> passed to annotate_label
shade -> passed to fill_between_y
ticklabels -> passed to xticks
remove_axis -> not passed anywhere, can only take False as value to skip calling remove_axis
- stats_kwargs : mapping, optional
Valid keys are:
credible_interval -> passed to eti or hdi
point_estimate -> passed to mean, median or mode
- pc_kwargs : mapping
Passed to arviz_plots.PlotCollection.grid
- Returns:
PlotCollection
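The selection rules for var_names and filter_vars can be illustrated with a small pure-Python sketch. The helper select_var_names is hypothetical, not part of arviz_plots or arviz_base: it works on a plain list of names rather than a DataTree and only assumes the behaviour described for those two parameters (exact names, substrings or regular expressions, with a leading ~ marking exclusions).

```python
import re


def select_var_names(all_names, var_names=None, filter_vars=None):
    """Hypothetical helper: resolve var_names/filter_vars into variable names."""
    if var_names is None:
        return list(all_names)
    if isinstance(var_names, str):
        var_names = [var_names]
    # A leading "~" marks a name (or pattern) to exclude rather than include.
    excluded = [name[1:] for name in var_names if name.startswith("~")]
    included = [name for name in var_names if not name.startswith("~")]

    def matches(pattern, name):
        if filter_vars is None:
            return pattern == name  # exact variable names
        if filter_vars == "like":
            return pattern in name  # substring match
        if filter_vars == "regex":
            return re.search(pattern, name) is not None  # regular expression
        raise ValueError(f"unknown filter_vars: {filter_vars!r}")

    # If only exclusions are given, start from every available variable.
    if included:
        candidates = [n for n in all_names if any(matches(p, n) for p in included)]
    else:
        candidates = list(all_names)
    return [n for n in candidates if not any(matches(p, n) for p in excluded)]


names = ["mu", "theta", "tau", "theta_t"]
print(select_var_names(names, ["theta"]))                      # ['theta']
print(select_var_names(names, ["theta"], filter_vars="like"))  # ['theta', 'theta_t']
print(select_var_names(names, ["~theta"], filter_vars="like")) # ['mu', 'tau']
print(select_var_names(names, ["^t"], filter_vars="regex"))    # ['theta', 'tau', 'theta_t']
```

Note how “like” makes ~theta drop theta_t as well, since theta is a substring of it.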
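To make ci_kind and ci_probs more concrete, here is a minimal sketch of equal-tailed (eti) and highest density (hdi) intervals computed from a list of samples with only the standard library. It is not the arviz_stats implementation: the quantiles skip interpolation and the HDI search assumes a unimodal distribution. The probability 0.94 is only an example value for the twig; in plot_forest the default comes from rcParams["stats.ci_prob"].

```python
import math


def eti(samples, prob):
    """Equal-tailed interval: drop (1 - prob) / 2 of the samples from each tail."""
    ordered = sorted(samples)
    n = len(ordered)
    tail = (1 - prob) / 2
    # Index-based quantiles without interpolation keep the sketch short.
    low = ordered[math.floor(tail * (n - 1))]
    high = ordered[math.ceil((1 - tail) * (n - 1))]
    return low, high


def hdi(samples, prob):
    """Highest density interval for unimodal samples: the narrowest window of
    consecutive sorted samples that still contains a fraction `prob` of them."""
    ordered = sorted(samples)
    n = len(ordered)
    width = math.ceil(prob * n)  # number of samples the interval must contain
    start = min(range(n - width + 1), key=lambda i: ordered[i + width - 1] - ordered[i])
    return ordered[start], ordered[start + width - 1]


# Deterministic, right-skewed "samples": evenly spaced Exponential(1) quantiles.
samples = [-math.log(1 - (i + 0.5) / 1000) for i in range(1000)]

trunk = eti(samples, 0.5)   # inner interval (first element of ci_probs)
twig = eti(samples, 0.94)   # outer interval (second element of ci_probs)
print("50% ETI (trunk):", trunk)
print("94% ETI (twig): ", twig)
print("94% HDI:        ", hdi(samples, 0.94))
```

On right-skewed samples like these, the HDI is pulled towards the mode and is narrower than the ETI of the same probability, and the 50% trunk interval sits inside the 94% twig interval, which is why ci_probs should be sorted.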
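The alternate shading enabled by shade_label, driven by the two alpha values (0, 0.3) described under aes_map, can be sketched as follows. The function shade_strips is hypothetical: it only computes strip boundaries and alpha values with the standard library, whereas plot_forest draws the strips through fill_between_y and may lay them out differently.

```python
from itertools import cycle


def shade_strips(values, alphas=(0, 0.3), spacing=1.0):
    """Hypothetical sketch: one horizontal strip per shade_label value, with
    alpha cycling through `alphas` so that neighbouring strips alternate."""
    strips = []
    for i, (value, alpha) in enumerate(zip(values, cycle(alphas))):
        y = i * spacing
        # Each strip covers half the spacing on either side of its row, so
        # consecutive strips touch and the shading looks continuous.
        strips.append((value, y - spacing / 2, y + spacing / 2, alpha))
    return strips


schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter"]
for value, y_low, y_high, alpha in shade_strips(schools):
    print(f"{value:<17} y=[{y_low}, {y_high}] alpha={alpha}")
```

An alpha of 0 makes every other strip fully transparent, which is what produces the alternating effect with a single fill colour.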
Notes
The separation between variables and between all of their coordinate values is set to 1. The only two exceptions are the dimensions named “chain” and “model”, if present, which get a smaller spacing to give a sense of grouping among visual elements that only differ in their chain or model id.
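The spacing rule in these notes can be sketched as a small layout function. This is a hypothetical illustration, not the arviz_plots code: rows (variables and their coordinate values) are placed 1 unit apart, while the chains of each row are packed closer together around it. The chain spacing of 0.2 and the centring of the chains on their row are assumptions; the same idea would apply to the “model” dimension.

```python
def y_positions(rows, n_chains, chain_spacing=0.2):
    """Hypothetical layout: consecutive rows (variable/coordinate combinations)
    are 1 apart, while the chains of one row are chain_spacing apart and
    centred on that row. chain_spacing=0.2 is an assumed value."""
    offset = (n_chains - 1) * chain_spacing / 2
    positions = {}
    for row_idx, row in enumerate(rows):
        for chain in range(n_chains):
            positions[(row, chain)] = row_idx - offset + chain * chain_spacing
    return positions


pos = y_positions(["mu", "tau", "theta[Choate]"], n_chains=4)
for key, y in pos.items():
    print(key, round(y, 2))
```

Because the chains of one row span less than the unit gap between rows, intervals that only differ by chain read as one visual group.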
Examples
The following examples focus on behaviour specific to plot_forest. For a general introduction to batteries-included functions like this one and common usage examples, see Introduction to batteries-included plots in arviz_plots.
Default forest plot for a single model:
>>> from arviz_plots import plot_forest, style
>>> style.use("arviz-clean")
>>> from arviz_base import load_arviz_data
>>> centered = load_arviz_data('centered_eight')
>>> non_centered = load_arviz_data('non_centered_eight')
>>> pc = plot_forest(centered)
Default forest plot for multiple models:
>>> pc = plot_forest({"centered": centered, "non centered": non_centered})
>>> pc.add_legend("model")
Single model forest plot with color mapped to the variable (a mapping which is also applied to the labels) and alternate shading per school. Moreover, to ensure the shading looks continuous, we specify that we don’t want to use constrained layout (set by the “arviz-clean” theme), and, to keep the labels from being too squished, we set the width_ratios for create_plotting_grid via pc_kwargs.
>>> pc = plot_forest(
...     non_centered,
...     var_names=["theta", "mu", "theta_t", "tau"],
...     pc_kwargs={
...         "aes": {"color": ["__variable__"]},
...         "plot_grid_kws": {"width_ratios": [1, 2], "layout": "none"}
...     },
...     aes_map={"labels": ["color"]},
...     shade_label="school",
... )
Extend the forest plot with an extra plot of ess estimates. To achieve that, we manually add a “column” dimension with size 3. plot_forest only plots on the “labels” and “forest” coordinate values, leaving the “ess” coordinate empty. Afterwards, we manually use PlotCollection.map with the ess result as data on the “ess” column to plot its values.
>>> from arviz_plots import visuals
>>> from arviz_stats.base import ess
>>>
>>> c_aux = centered["posterior"].expand_dims(
...     column=3
... ).assign_coords(column=["labels", "forest", "ess"])
>>> pc = plot_forest(c_aux, combined=True)
>>> pc.map(
...     visuals.scatter_x, "ess", data=ess(centered),
...     coords={"column": "ess"}, color="C0"
... )
Note that we are using the same PlotCollection, so when calling map, all the aesthetic mappings used by plot_forest are applied as well.