Skip to content

plot_equivalent_dags

Signature/Parameters

def plot_equivalent_dags(self, use_labels = True, show_labels = True, edge_difference_color = 'red', title_fontsize = 10, title_original_graph = 'Original Graph', title_equivalent_graph = 'Equivalent DAG', show_footnote = True, figsize = (16, 9), max_per_figure = 9, max_eq_dags = 27, **plot_kws)

Visualize multiple DAGs in the Markov equivalence class.

Parameters:

Name Type Description Default
use_labels bool

Prefer custom node labels when True (default).

True
show_labels bool

Display node labels on the plots. Defaults to True.

True
edge_difference_color str

Color used to highlight edges that differ from the original graph in each equivalent DAG. Defaults to 'red'.

'red'
title_fontsize int

Font size for subplot titles. Defaults to 10.

10
title_original_graph str

Title intended for a baseline plot of the original DAG. Note: the current implementation does not use this parameter.

'Original Graph'
title_equivalent_graph str

Title applied to each equivalent DAG subplot.

'Equivalent DAG'
show_footnote bool

Display a numbered footnote beneath each subplot when True.

True
figsize tuple[float, float]

Figure size in inches for each panel grid. Defaults to (16, 9).

(16, 9)
max_per_figure int

Maximum number of panels per figure. Defaults to 9.

9
max_eq_dags int

Cap on the number of equivalent DAGs to display. Defaults to 27.

27
**plot_kws

Additional keyword arguments forwarded to DAG.plot.

{}

Returns:

Type Description
dict[int, list]

Mapping from figure index to [figure, axes_list] pairs. Returns None when no equivalent DAGs exist.

Examples:

>>> G = DAG(graph="X -> Z <- Y")
>>> figs = G.plot_equivalent_dags(show_footnote=False, max_eq_dags=4)
>>> isinstance(figs, dict)
True
Source code in causalinf/gcm.py
def plot_equivalent_dags(self,
                         use_labels=True,
                         show_labels=True,
                         edge_difference_color='red',
                         title_fontsize=10,
                         title_original_graph='Original Graph',
                         title_equivalent_graph="Equivalent DAG",
                         show_footnote=True,
                         figsize=(16, 9),
                         max_per_figure=9,
                         max_eq_dags=27,
                         **plot_kws
                         ):
    """
    Visualize multiple DAGs in the Markov equivalence class.

    Each DAG returned by ``self.equivalent_dags()`` is drawn in its own
    panel. The edges reported by ``self.edge_differences(eq_dag)['G2']`` are
    then redrawn on top, thicker and in ``edge_difference_color``, so the
    places where the equivalent DAG differs from this graph stand out.
    Panels are laid out on a near-square grid sized from ``max_per_figure``
    and spread over as many figures as needed. Progress is printed to stdout.

    Parameters
    ----------
    use_labels : bool, optional
        Prefer custom node labels when ``True`` (default).
    show_labels : bool, optional
        Display node labels on the plots. Defaults to ``True``.
    edge_difference_color : str, optional
        Color used to highlight edges that differ from the original graph
        in each equivalent DAG. Defaults to ``'red'``.
    title_fontsize : int, optional
        Font size for subplot titles. Defaults to ``10``.
    title_original_graph : str, optional
        Currently not used by this method; accepted for API compatibility.
    title_equivalent_graph : str, optional
        Title applied to each equivalent DAG subplot.
    show_footnote : bool, optional
        Display a numbered footnote beneath each subplot when ``True``.
    figsize : tuple[float, float], optional
        Figure size in inches for each panel grid. Defaults to ``(16, 9)``.
    max_per_figure : int, optional
        Maximum number of panels per figure. Must be at least 1.
        Defaults to ``9``.
    max_eq_dags : int, optional
        Cap on the number of equivalent DAGs to display. Defaults to ``27``.
    **plot_kws :
        Additional keyword arguments forwarded to ``DAG.plot``. The keys
        ``node_subset`` (restricts the drawn nodes, including highlighted
        ones) and ``legend_show`` (legend shown only on the first panel when
        true; default ``True``) are consumed here rather than forwarded
        as-is.

    Returns
    -------
    dict[int, list] or None
        Mapping from figure index to ``[figure, axes]`` pairs. Returns
        ``None`` when no equivalent DAGs exist.

    Raises
    ------
    ValueError
        If at least one equivalent DAG exists and ``max_per_figure < 1``.

    Examples
    --------
    >>> G = DAG(graph="X -> Z <- Y")
    >>> figs = G.plot_equivalent_dags(show_footnote=False, max_eq_dags=4)
    >>> isinstance(figs, dict)
    True
    """
    # collecting equivalent DAGs
    eq_dags = self.equivalent_dags()
    n_eq_dags = len(eq_dags)
    if n_eq_dags == 0:
        return None

    # A non-positive panel count would otherwise surface as an obscure
    # ZeroDivisionError / math domain error in the grid-size computation.
    if max_per_figure < 1:
        raise ValueError(
            f"'max_per_figure' must be at least 1, got {max_per_figure}")

    if n_eq_dags > max_eq_dags:
        print(f"\n**Note:**\n"
              f"---------\n"
              f"Maximum number of equivalent DAGs to plot is set to {max_eq_dags}"
              f" by default, but there are {n_eq_dags} equivalent DAGs. Some equivalent DAGs"
              f" will be omitted. To change it, set 'max_eq_dags'.\n")

    n_to_plot = min(n_eq_dags, max_eq_dags)
    # presumably maps figure index -> iterable of DAG indices for that figure
    # (consumed below as such) -- TODO confirm against __chunked_ranges__.
    figs = dict(self.__chunked_ranges__(n_to_plot, max_per_figure))

    print(f"Total of equivalent DAGs: {n_eq_dags}\n"
          f"Plotting {n_to_plot} equivalent DAG(s)\n"
          f"Generating {len(figs)} figure(s) with a maximum of {max_per_figure} panels per figure\n")

    # These options are interpreted here instead of being forwarded verbatim.
    nodes_subset = plot_kws.pop("node_subset", None)
    legend_show = plot_kws.pop("legend_show", True)
    subset_lookup = set(nodes_subset) if nodes_subset is not None else None

    # Grid shape depends only on max_per_figure, so compute it once.
    ncols = int(math.ceil(math.sqrt(max_per_figure)))
    nrows = int(math.ceil(max_per_figure / ncols))

    figs_res = {}
    for fig_number, panels in figs.items():
        fig, axs = plt.subplots(nrows, ncols, figsize=figsize, tight_layout=True)
        # plt.subplots returns a bare Axes for a 1x1 grid; normalize to a sequence.
        if ncols > 1 or nrows > 1:
            axs = axs.flatten()
        else:
            axs = [axs]
        # Hide every axis; only the panels actually used are switched back on.
        for ax in axs:
            ax.axis('off')

        for panel, panel_number in enumerate(panels):
            print(f"Creating plot {panel_number+1} of {n_eq_dags}...", end='')
            ax = axs[panel]
            eq_dag = eq_dags[panel_number]
            # Show the legend once, on the very first equivalent DAG only.
            panel_legend_show = legend_show and panel_number == 0

            # baseline plot
            eq_dag.plot(ax=ax,
                        node_subset=nodes_subset,
                        legend_show=panel_legend_show,
                        edge_linewidth=1,
                        show_labels=show_labels,
                        use_labels=use_labels,
                        title=title_equivalent_graph,
                        title_fontsize=title_fontsize,
                        **plot_kws)

            # superimpose edges highlighting the differences
            edges = self.edge_differences(eq_dag)['G2']
            nodes = self.__collect_nodes_from_edges__(edges)
            if subset_lookup is not None:
                # Deduplicate while keeping a deterministic order.
                nodes = [n for n in dict.fromkeys(nodes) if n in subset_lookup]

            if nodes:
                eq_dag.plot(ax=ax, edge_linewidth=3,
                            node_subset=nodes,
                            edge_subset=edges,
                            legend_show=False,
                            show_labels=show_labels,
                            edge_color=edge_difference_color,
                            use_labels=use_labels,
                            title=title_equivalent_graph,
                            title_fontsize=title_fontsize,
                            **plot_kws)

            if show_footnote:
                # footnote, right-aligned just below the axes
                xcoord = 1
                yoffset = -.1
                fn = f"Equivalent DAG: {panel_number+1} of {n_eq_dags}"
                ax.annotate(fn, xy=(xcoord, yoffset), xytext=(xcoord, yoffset),
                            xycoords='axes fraction', size=11, ha='right',
                            style='italic', alpha=.6)
            print('done!')
            ax.axis('on')

        # Lay out and record each figure once, after all its panels are drawn.
        fig.tight_layout()
        figs_res[fig_number] = [fig, axs]
    return figs_res