Skip to content

plot_paths

Signature/Parameters

def plot_paths(self, exposure = None, outcome = None, adj_set = None, directed = False, show_full_dag = True, use_labels = True, title_fontsize = 10, figsize = (16, 9), path_color = 'black', **plot_kws)

Plot individual paths between exposure and outcome nodes.

Parameters:

Name Type Description Default
exposure str or list[str] or None

Exposure node(s) to anchor the paths. Defaults to the DAG exposure role when omitted.

None
outcome str or list[str] or None

Outcome node(s) serving as path targets. Defaults to the DAG outcome role when omitted.

None
adj_set str or Sequence[str] or None

Adjustment set used to assess path openness. Strings are promoted to single-element lists.

None
directed bool

If True, restrict to directed paths from exposure to outcome. Defaults to False.

False
show_full_dag bool

Draw the entire DAG in the background with muted styling before highlighting each path. Defaults to True.

True
use_labels bool

When True (default), prefer custom node labels over names.

True
title_fontsize int

Font size for subplot titles. Defaults to 10.

10
figsize tuple[float, float]

Size of the grid of path plots in inches. Defaults to (16, 9).

(16, 9)
path_color str

Color applied to highlighted path edges. Defaults to 'black'.

'black'
**plot_kws

Additional keyword arguments forwarded to DAG.plot for both the background DAG (when show_full_dag is True) and each path.

{}

Returns:

Type Description
list[Axes]

Axes objects for the generated subplots, returned as a flat list (even when the grid contains a single axis).

Examples:

>>> G = DAG(graph="X -> Z -> Y")
>>> axes = G.plot_paths(exposure="X", outcome="Y", directed=True, show_full_dag=False)
>>> len(axes)
1
Source code in causalinf/gcm.py
def plot_paths(self, exposure=None, outcome=None, adj_set=None, directed=False,
               show_full_dag=True,
               use_labels=True,
               title_fontsize=10,
               figsize=(16, 9),
               path_color='black',
               **plot_kws
               ):
    """
    Plot individual paths between exposure and outcome nodes.

    Each path returned by ``self.paths`` is drawn in its own subplot of a
    near-square grid. The subplot title states whether the path is open or
    closed and lists the adjustment set used.

    Parameters
    ----------
    exposure : str or list[str] or None, optional
        Exposure node(s) to anchor the paths. Defaults to the DAG exposure
        role when omitted.
    outcome : str or list[str] or None, optional
        Outcome node(s) serving as path targets. Defaults to the DAG outcome
        role when omitted.
    adj_set : str or Sequence[str] or None, optional
        Adjustment set used to assess path openness. Strings are promoted to
        single-element lists.
    directed : bool, optional
        If ``True``, restrict to directed paths from exposure to outcome.
        Defaults to ``False``.
    show_full_dag : bool, optional
        Draw the entire DAG in the background with muted styling before
        highlighting each path. Defaults to ``True``.
    use_labels : bool, optional
        When ``True`` (default), prefer custom node labels over names. Applies
        to node labels, the background DAG, and the adjustment set shown in
        the title.
    title_fontsize : int, optional
        Font size for subplot titles. Defaults to ``10``.
    figsize : tuple[float, float], optional
        Size of the grid of path plots in inches. Defaults to ``(16, 9)``.
    path_color : str, optional
        Color applied to highlighted path edges. Defaults to ``'black'``.
    **plot_kws :
        Additional keyword arguments forwarded to ``DAG.plot`` for both the
        background DAG (when ``show_full_dag`` is ``True``) and each path.

    Returns
    -------
    list[matplotlib.axes.Axes]
        Axes objects for every cell of the subplot grid, flattened row by row
        (also when the grid has a single cell). Cells beyond the number of
        paths have their axis turned off. When no paths are found, an empty
        list is returned and no figure is created.

    Raises
    ------
    ValueError
        If ``show_full_dag`` is ``True`` but ``self.nodes_position`` is not set.

    Notes
    -----
    Titles are typeset with LaTeX. ``text.usetex`` is enabled while the figure
    is built and is restored afterwards, even if plotting raises. The
    ``text.latex.preamble`` rcParam is left set on purpose (see comment in the
    body).

    Examples
    --------
    >>> G = DAG(graph="X -> Z -> Y")
    >>> axes = G.plot_paths(exposure="X", outcome="Y", directed=True, show_full_dag=False)
    >>> len(axes)
    1
    """
    # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
    if show_full_dag and not self.nodes_position:
        raise ValueError("Nodes position must be set when show_full_dag=True")

    # A bare node name is promoted to a single-element adjustment set.
    if isinstance(adj_set, str):
        adj_set = [adj_set]

    paths = self.paths(exposure=exposure, outcome=outcome, adj_set=adj_set, directed=directed)
    npaths = len(paths)
    if npaths == 0:
        # Nothing to draw; also avoids dividing by zero when sizing the grid.
        return []

    # Near-square grid: ceil(sqrt(n)) columns, and as many rows as needed.
    ncols = int(math.ceil(math.sqrt(npaths)))
    nrows = int(math.ceil(npaths / ncols))

    default_usetex = plt.rcParams["text.usetex"]
    plt.rcParams["text.usetex"] = True
    packages = ["amsmath", "amssymb", "siunitx", "bm", "wasysym", "marvosym"]
    # The preamble is intentionally NOT restored: LaTeX text is typeset when
    # the figure is drawn, which may happen after this method returns.
    plt.rcParams['text.latex.preamble'] = rf"\usepackage{{{', '.join(packages)}}}"
    try:
        # squeeze=False always yields a 2-D array, so flattening is uniform
        # whether the grid has one cell or many.
        fig, grid = plt.subplots(nrows, ncols, figsize=figsize,
                                 tight_layout=True, squeeze=False)
        axs = list(grid.flat)
        # Hide every cell; cells that receive a path are switched back on below.
        for ax in axs:
            ax.axis('off')

        for ax, (path, info) in zip(axs, paths.items()):
            if show_full_dag:
                # Background carries the node labels, so it must honor use_labels.
                self.plot(ax=ax, edge_color='lightgray', use_labels=use_labels, **plot_kws)
            # Avoid drawing node labels twice when the background already has them.
            show_labels = not show_full_dag

            path_graph = self.__rebuild_graph__(path)
            path_graph.plot(ax=ax, edge_linewidth=3, show_labels=show_labels,
                            edge_color=path_color, use_labels=use_labels, **plot_kws)

            adj = info['adj_set']
            if adj:
                if use_labels:
                    adj = [self.nodes_label.get(x, x) for x in adj]
                adj_text = ', '.join(adj)
            else:
                adj_text = ""
            status = 'open' if info['open'] else 'closed'
            # Raw strings: "\{" / "\}" are LaTeX escaped braces, not Python escapes.
            title = rf"Path is \textbf{{{status}}}; Adjustment set: " + r"\{" + adj_text + r"\}"
            ax.set_title(title, loc='left', fontsize=title_fontsize)
            ax.axis('on')

        fig.tight_layout()
    finally:
        # Restore global state even if any plotting call above raised.
        plt.rcParams["text.usetex"] = default_usetex
    return axs