
summary

Signature/Parameters

class summary
def __init__(self, model = None, model_name = 'Model 1', compare = None, output = 'text', style = None, omit = None, show_sig = True, show_se = False, show_ci = True, show_fit = True, digits = 4, digits_fit = 2, col_width = 1000, col_width_term = 15, latex_kws = None, fn = None, save_style = 'concise', save_copies = ['csv', 'xlsx'], *args, **kws)

model (causalinf.<strategy>.estimate) An estimate object from causalinf.

compare (dict or list) A list or dict of other causalinf estimate objects, whose estimates are shown in separate columns. If a dictionary is used, its keys are used as the column names. The column of the object calling the summary is named by ‘model_name’. If a list is provided, the columns are named “Model 2”, “Model 3”, etc., since the calling model takes the first column (“Model 1” by default).
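The column naming can be sketched in plain Python. `column_names` is a hypothetical helper (not part of causalinf) and plain strings stand in for estimate objects; it follows the naming and the `model | compare` merge order used in `collect_models()` in the source below.

```python
def column_names(compare=None, model_name="Model 1"):
    """Return the column order summary() would produce (hypothetical helper)."""
    if not compare:
        compare = {}
    elif isinstance(compare, list):
        # The calling model takes column 1, so compared models start at "Model 2"
        names = [f"Model {i}" for i in range(2, len(compare) + 2)]
        compare = dict(zip(names, compare))
    # Same merge as collect_models(): calling model first, then the others.
    # On a name clash, the compared model's entry replaces the calling model's.
    models = {model_name: "calling model"} | dict(compare)
    return list(models)


print(column_names(["est_b", "est_c"]))
print(column_names({"IV": "est_b"}, model_name="OLS"))
```

Because the dictionaries are merged with `|`, a compared model named like `model_name` (for example `model_name="Model 2"` together with a list in `compare`) would silently replace the calling model's column, so distinct names are safer.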

output (str) Format of the output, stored in the ‘res’ attribute:
- ‘text’: prints the summary; ‘res’ is None
- ‘tibble’: a tibble
- ‘latex’: a LaTeX table (string)
To save the table to a file, use ‘fn’. The output and the saved version are independent; for instance, it is possible to print the summary (text) and at the same time save it in LaTeX using ‘fn’.

model_name (str) Name of the column showing the estimates of the model calling the summary (also printed as ‘Model:’ in the text output).

style (str) If style=’concise’, the summary table includes only:
- the parameter name (‘term’)
- the estimate, with the p-value indicator (if show_sig=True), the standard error (if show_se=True), and the confidence interval (if show_ci=True)
If style=’full’, the summary table includes all available estimation statistics. Defaults:
- ‘full’ when compare=None and output=’text’
- ‘concise’ otherwise
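The default rule above (`style or self.get_style(compare, output)` in the source) can be restated as a small function. `resolve_style` is a hypothetical name used only for illustration.

```python
def resolve_style(style=None, compare=None, output="text"):
    """Pick the table style the way summary() does (a sketch)."""
    if style:  # an explicit style always wins
        return style
    # Only a single model printed as text defaults to the full table
    return "full" if compare is None and output == "text" else "concise"


print(resolve_style())                       # single model, text output
print(resolve_style(compare=["other_est"]))  # comparison
print(resolve_style(output="latex", style="full"))
```

Since the check is `compare is None`, an empty list passed as `compare` still selects the concise style.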

omit (str) A regular expression matched against the entries of the ‘term’ column. Matching terms are omitted from the table.
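`omit` is applied as a regular-expression search, so a pattern matches anywhere in a term name unless it is anchored with `^` and `$`. The sketch below uses Python's `re.search` as a stand-in for the `str.contains` filter in the source, assuming that filter behaves like a polars regex search; `drop_terms` and the term names are made up for illustration.

```python
import re


def drop_terms(terms, omit=None):
    """Keep the terms that do NOT match the regex `omit` (a sketch)."""
    if not omit:
        return list(terms)
    # A search, not a full match: the pattern may hit anywhere in the name
    return [t for t in terms if not re.search(omit, t)]


terms = ["(Intercept)", "age", "age2", "educ", "region[T.South]"]
print(drop_terms(terms, omit="age"))                 # drops age and age2
print(drop_terms(terms, omit=r"^age$"))              # drops only age
print(drop_terms(terms, omit=r"Intercept|^region"))  # alternation in one pattern
```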

save_style (str) Same as ‘style’, but used for the files saved through ‘fn’ and ‘save_copies’.

fn (str) Path, including the file name, where the output is saved. Relative paths and ‘~’ are allowed. The file format is chosen from the extension (.tex, .xlsx, .xls, or .csv). Additional copies are saved according to ‘save_copies’.

save_copies (list) List of file extensions, without the leading dot, in which to save copies of the output table next to ‘fn’ (same path and file name). Available are xls, xlsx, and csv. Copies are only saved when ‘fn’ is set; if None, no copies are saved.
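A sketch of which files end up on disk for a given `fn` and `save_copies`, following `_save()` and `_save_copies()` in the source. `output_paths` is a hypothetical helper that only computes the paths and writes nothing.

```python
import os


def output_paths(fn, save_copies=("csv", "xlsx")):
    """List the files summary() would write for `fn` and `save_copies` (a sketch)."""
    if not fn:
        return []  # nothing is saved without 'fn', copies included
    root, ext = os.path.splitext(fn)
    if not ext:
        raise ValueError("The file name in 'fn' is missing the extension.")
    paths = [os.path.expanduser(fn)]
    # Copies reuse the path and stem of 'fn'; extensions are given without the dot
    for copy_ext in save_copies or []:
        paths.append(os.path.expanduser(f"{root}.{copy_ext}"))
    return paths


print(output_paths("results/table1.tex"))
print(output_paths("results/table1.tex", save_copies=None))
```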

show_fit (bool or list) If False, fit statistics are omitted; if True, the statistics listed in the model's est.fit are shown; otherwise, only the statistics included in the provided list are shown.

show_sig, show_se, show_ci (bool) In the concise style, these flags set which information, whenever available, appears alongside the parameter estimates: significance level indicators (sig), standard errors (se), and confidence intervals (ci). They have no effect on the full style, which is the default when output=’text’ and compare=None.
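In the concise style, each model's column holds one text cell per term. The sketch below mirrors the string handling of `collect_summary_tidy_formatted()` with plain Python: the parts are joined with spaces, a line break is inserted before each `(`, and ` *` is collapsed so the stars attach to the estimate. `format_cell` is hypothetical, every number is formatted with `digits` decimal places, and the `stars` argument assumes the `sig` column holds strings such as `'***'`.

```python
def format_cell(estimate, stars, se, lo, hi, digits=4,
                show_sig=True, show_se=False, show_ci=True):
    """Build one concise-table cell (a sketch of the formatting steps)."""
    parts = [f"{estimate:.{digits}f}"]
    if show_sig:
        parts.append(stars)
    if show_se:
        parts.append(f"({se:.{digits}f})")
    if show_ci:
        parts.append(f"({lo:.{digits}f}, {hi:.{digits}f})")
    cell = " ".join(parts)
    # The se and CI go on their own lines; the stars attach to the estimate
    cell = cell.replace("(", "\n(")
    return cell.replace(" *", "*")


print(format_cell(0.12341, "***", 0.01237, 0.09912, 0.14768, show_se=True))
```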

digits, digits_fit (int) Number of decimal places shown in the estimates and in the fit statistics, respectively.
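`digits` acts as a number of decimal places (the estimates are rounded with `round(digits)` in the source). When numbers are turned into text with Python's format specification, that corresponds to the `f` presentation type; a precision without a type counts significant digits instead and switches to scientific notation for large values. `show` is a hypothetical helper that prints both forms side by side.

```python
def show(x, digits=4):
    # 'f' -> fixed number of decimals; no type -> significant digits (like 'g')
    return f"{x:.{digits}f}", f"{x:.{digits}}"


for x in (0.123456, 1.5, 12345.678):
    print(x, show(x))
```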

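col_width (int) Width of the columns in the printed summary.

col_width_term (int) Maximum number of characters of the term names in the printed summary; longer names are truncated.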

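latex_kws (dict) Keyword arguments passed to tibble.to_latex(). If ‘align’ is not given, the term column is left-aligned and the model columns are centered. The model information and the significance codes are appended to the left-aligned footnotes (‘l’ key of ‘footnotes’).

**kws Optional keywords: ‘id_strategy’ and ‘formula’ (str), shown in the header of the text output, and ‘latex_replace’ (dict), regular-expression replacements applied to the term names in the LaTeX output.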
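The defaults that `_output_latex()` fills into `latex_kws` can be sketched with plain dictionaries. `fill_latex_kws` is hypothetical, and `info` and `star_codes` are placeholder strings standing in for the model information and the significance codes; unlike the source, the sketch works on a copy instead of modifying the caller's dictionary.

```python
def fill_latex_kws(latex_kws, ncols, info, star_codes):
    """Fill in the latex_kws defaults as summary() appears to (a sketch)."""
    kws = dict(latex_kws or {})
    footnotes = dict(kws.get("footnotes", {"l": []}))
    if "l" in footnotes:
        notes = footnotes["l"]
        notes = notes if isinstance(notes, list) else [notes]
        # Model info and significance codes go after the user's left-aligned notes
        footnotes["l"] = notes + [info, star_codes]
    kws["footnotes"] = footnotes
    if not kws.get("align"):
        # First column (terms) left-aligned, one centered column per model
        kws["align"] = "l" + "c" * (ncols - 1)
    return kws


print(fill_latex_kws(None, ncols=3, info="Model 1: n=100", star_codes="codes"))
print(fill_latex_kws({"align": "lrr", "footnotes": {"l": "Robust SEs."}}, 3, "info", "codes"))
```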

Source code in causalinf/utils.py
class summary:
    """
    model (causalinf.<strategy>.estimate)
        An estimate object from causalinf

    compare (dict or list)
        A list or dict of other causalinf estimate objects. The
        estimates will be shown in different columns.
        If a dictionary is used, the keys will be used as the column names.
        For the column name of the object calling the summary, use
        'model_name'.
        If a list is provided, names will be set to "Model 2", "Model 3",
        etc., since the calling model takes the first column.

    output (str)
        Format of the output:
        - 'text': prints the summary; the 'res' attribute is None
        - 'tibble': 'res' holds a tibble
        - 'latex': 'res' holds a LaTeX table (string)
        To save the file, use 'fn'. The output and the saved version
        are independent. For instance, it is possible to print the
        summary (text) and at the same time save it in LaTeX using
        'fn'.

    model_name (str)
        Name of the column showing the estimates of the model calling
        the summary (also printed as 'Model:' in the text output).

    style (str)
        If style='concise', the summary table includes only
        - The parameter name ('term')
        - The estimate, with the p-value indicator (if show_sig=True),
          the std. error (if show_se=True), and the confidence interval
          (if show_ci=True)
        If style='full', the summary table includes all available
        estimation statistics.
        Defaults:
        - 'full' when compare=None and output='text'
        - 'concise' otherwise

    omit (str)
        A regular expression matched against the entries of the 'term'
        column. Matching terms are omitted.

    save_style (str)
        Same as 'style', but used for the files saved through 'fn'
        and 'save_copies'.

    fn (str)
        Path, including the file name, where the output is saved.
        Relative paths and '~' are allowed. The file format is chosen
        from the extension (.tex, .xlsx, .xls, or .csv). Copies are
        saved based on 'save_copies'.

    save_copies (list)
        List of file extensions, without the leading dot, in which to
        save copies of the output table next to 'fn'. Available are
        xls, xlsx, and csv. Copies are only saved when 'fn' is set.
        If None, it will not save copies of the output.

    show_fit (bool or list)
        If False, omit fit statistics; if True, show the statistics
        listed in the model's est.fit; otherwise, show only the
        statistics included in the list provided.

    show_sig, show_se, show_ci (bool)
        In the concise style, show_* set which information, whenever
        available, appears alongside the parameter estimates:
        significance level indicators (sig), standard errors (se), and
        confidence intervals (ci). They have no effect on the full
        style, the default when output='text' and compare=None.

    digits, digits_fit (int)
        Number of decimal places shown in the estimates and fit
        statistics, respectively.

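    col_width (int)
        Width of the columns in the printed summary.

    col_width_term (int)
        Maximum number of characters of the term names in the printed
        summary; longer names are truncated.

    latex_kws (dict)
        Keyword arguments passed to tibble.to_latex(). If 'align' is
        not given, the term column is left-aligned and the model
        columns are centered. The model information and significance
        codes are appended to the left-aligned footnotes ('l' key of
        'footnotes').

    **kws
        Optional: 'id_strategy' and 'formula' (str), shown in the
        header of the text output, and 'latex_replace' (dict),
        regular-expression replacements applied to the term names in
        the LaTeX output.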

    """

    def __init__(self,
                 model=None,
                 model_name = 'Model 1',
                 compare=None,
                 output = 'text',
                 style = None,
                 omit = None,
                 show_sig = True,
                 show_se =  False,
                 show_ci =  True,
                 show_fit = True,
                 digits = 4,
                 digits_fit = 2,
                 col_width = 1000,
                 col_width_term = 15,
                 # latex args
                 latex_kws=None,
                 fn = None,
                 save_style = 'concise', 
                 save_copies = ['csv', 'xlsx'],
                 *args, **kws
                 ):
        # compare = kws.get("compare", None)
        assert isinstance(latex_kws, dict | None), "'latex_kws' must be None or a dict."
        assert isinstance(save_copies, list | None), "'save_copies' must be None or a list of file extensions."

        self.model_name = model_name
        self.output = output
        self.style = style or self.get_style(compare, output)
        self.omit = omit
        self.digits = digits
        self.digits_fit = digits_fit
        self.show_sig=show_sig
        self.show_se=show_se
        self.show_ci=show_ci
        self.show_fit = show_fit # self.show_fit = kws.get("show_fit", True)
        self.latex_kws = latex_kws or {}
        self.fn = fn
        self.save_style = save_style
        self.save_copies = save_copies
        self.col_width = col_width
        self.col_width_term = col_width_term

        assert model is not None, "'model' must be a causalinf estimate object."
        self.outcome = model.outcome
        self.exposure = model.exposure

        self.collect_models(model, model_name, compare)
        self.merge_models()
        self.collect_info()

        # # implicit parameters
        self.id_strategy = kws.get("id_strategy", '')
        self.formula = kws.get("formula", '')
        self.latex_replace = kws.get("latex_replace", None) ## for latex only
        self.estimator = model.est.fit['Estimator']
        self.footnote_added = False # used to avoid duplicating footnote entries

        # keep this order
        self._save(fn=self.fn, silent=False)
        self._save_copies()
        self._output(self.style)

    def collect_models(self, model, model_name, compare):
        assert isinstance(compare, list | dict | None), "'compare' must be a list of dict."
        model = {model_name: {'parameters': self.collect_summary_tidy_formatted(model.est.parameters, model_name),
                              'parameters_full': self.collect_summary_tidy(model.est.parameters),
                              'fit':model.est.fit,
                              'fit_tidy':model.est.fit_tidy(colname=model_name, digits=self.digits_fit),
                              'info':model.est.info
                              }}
        if not compare:
            compare = {}
        else:
            if isinstance(compare, list):
                # The calling model is column 1, so compared models are
                # named "Model 2", "Model 3", ...
                model_names = [f"Model {i}" for i in range(2, len(compare) + 2)]
                compare = dict(zip(model_names, compare))

            compare = {model_name:{'parameters':self.collect_summary_tidy_formatted(model.est.parameters, model_name),
                                   'parameters_full': self.collect_summary_tidy(model.est.parameters),
                                   'fit': model.est.fit,
                                   'fit_tidy':model.est.fit_tidy(colname=model_name, digits=self.digits_fit),
                                   'info':model.est.info
                                   } for model_name, model in compare.items()}

        models = model | compare
        self.models = models

    def collect_summary_tidy(self, parameters):
        parameters = (
            parameters
            .mutate(estimate = tp.col('estimate').round(self.digits),
                    se = tp.round('se', self.digits),
                    lo = tp.round('lo', self.digits),
                    hi = tp.round('hi', self.digits),
                    statistic = tp.round('statistic', self.digits),
                    pvalue = tp.round('pvalue', self.digits),
                    # ci = tp.map(['lo', 'hi'], lambda col: f"({col[0]}, {col[1]})")
                    )
        )
        if self.omit:
            parameters = (
                parameters
                .filter(~tp.col("term").str.contains(self.omit))
            )
        return parameters

    def collect_summary_tidy_formatted(self, parameters, model_name):
        cols_unite = ['estimate']
        if self.show_sig:
            cols_unite += ['sig']
        if self.show_se:
            cols_unite += ['se']
        if self.show_ci:
            cols_unite += ['ci']

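        # Apply 'omit' here as well, so it also affects the concise tables
        if self.omit:
            parameters = parameters.filter(~tp.col("term").str.contains(self.omit))

        # Format numbers with `digits` decimal places ('f'); a bare precision
        # would count significant digits and can switch to 1e+04 notation
        parameters = (
            parameters
            .mutate(estimate = tp.col('estimate').round(self.digits),
                    ci = tp.map(['lo', 'hi'], lambda col:
                                f"({col[0]:.{self.digits}f}, {col[1]:.{self.digits}f})"),
                    se = tp.map(['se'], lambda col: f"({col[0]:.{self.digits}f})"))
            .select('term', cols_unite)
            .unite(model_name, cols_unite, sep=' ')
            # Put the se and CI on their own lines and attach the stars to the estimate
            .mutate(**{model_name : tp.str_replace_all(model_name, r'\(', '\n(')})
            .mutate(**{model_name : tp.str_replace_all(model_name, r' \*', '*')})
        )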

        return parameters

    def merge_models(self):
        concise_parameter = tp.tibble()
        concise_fit_stats = tp.tibble()
        full = tp.tibble()
        for model_name, summary in self.models.items():
            fit_stats = summary['fit_tidy']
            concise_parameter = self.merge_models_concise(concise_parameter, summary['parameters'])
            concise_fit_stats = self.merge_models_concise(concise_fit_stats, fit_stats)
            full = self.merge_models_full(full, summary['parameters_full'], model_name, fit_stats)
        concise = concise_parameter.bind_rows(concise_fit_stats)
        self.merged = {'concise':concise, 'full':full}

    def merge_models_concise(self, base, to_merge):
        if base.nrow>0:
            base = base.full_join(to_merge, on='term', suffix='_right')
            merged = (base
                      .replace_null({'term':'', 'term_right':''})
                      .mutate(term = tp.case_when(tp.col('term')!=tp.col('term_right'),
                                                  tp.col('term')+tp.col('term_right'),
                                                  True, tp.col('term')
                                                  ))
                      .drop('term_right')
                      )
        else:
            merged = to_merge

        return merged

    def merge_models_full(self, base, to_merge, model_name, fit_stats):
        if self.show_fit:
            to_merge = to_merge.mutate(estimate = tp.as_character('estimate'))
            to_merge = to_merge.bind_rows(fit_stats.rename({model_name:'estimate'}))

        cols = to_merge.names
        to_merge = (to_merge
                    .mutate(Id = model_name)
                    .select('Id', cols)
                    )
        merged = base.bind_rows(to_merge)
        return merged

    def collect_info(self):
        info = []
        for model_name, model_info in self.models.items():
            info += [f"{model_name}: {model_info['info']}"]
        self.info = '; '.join(info)

    def _output(self, style):
        self.res_latex = self._output_latex(style)
        self.res_tibble = self._output_tibble(style)

        if self.output=='text':
            self.res = None
            self._output_text(style)
        elif self.output=='tibble':
            self.res = self.res_tibble 
        elif self.output=='latex':
            self.res = self.res_latex 

    def _output_text(self, style):
        tab = self.merged[style]
        if 'Id' in tab.names:
            if tab.pull('Id').unique().len()==1:
                tab = tab.drop('Id')

        line = "="*80
        print(dedent(f"""
        {line}
        Model: {self.model_name}
        Identification: {self.id_strategy}
        Outcome: {self.outcome}
        Exposure: {', '.join(self.exposure or ['Not set'])}
        Formula: {self.formula}
        Summary:
        --------\
        """))
        tab = tab.mutate(term = tp.map(['term'], lambda col: [*col][0][:self.col_width_term]))
        printDF(tab.to_polars().with_columns(pl.all().cast(pl.Utf8)).fill_null('--'),
                col_width=self.col_width)
        print(line)
        print(get_stars(outcome='codes'))
        if self.info is not None:
            print(textwrap.fill(self.info, 80) )

        return ''

    def _output_tibble(self, style):
        return  self.merged[style]

    def _output_latex(self, style, *args, **kws):
        tab = self.merged[style]
        latex_kws = self.latex_kws

        footnotes = latex_kws.get("footnotes", {'l':[]})
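        for align, note in footnotes.items():
            note = note if isinstance(note, list) else [note]
            # Append the model info and significance codes only once (this
            # method runs for the saved file and again for the displayed output)
            if align=='l' and not self.footnote_added:
                footnotes[align] = note + [self.info] + [ut.get_stars(outcome='codes', latex=True)]
                self.footnote_added = True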
        latex_kws['footnotes'] = footnotes 

        if not latex_kws.get("align", None):
            ncols = tab.ncol
            latex_kws['align'] = f"l{'c'*(ncols-1)}"

        if self.latex_replace:
            tab = tab.replace({'term':self.latex_replace}, regex=True)

        tab = tp.from_polars(tab.to_polars().with_columns(pl.all().cast(pl.Utf8)))
        tab = tab.rename({"term": ' '}).to_latex(**latex_kws)
        return tab

    def get_style(self, compare, output):
        if compare is None and output == 'text':
            res = 'full'
        else:
            res = 'concise'
        return res

    def _save(self, fn, silent=False, *args, **kws):
        excel = ['.xls', '.xlsx']
        if fn:
            _, ext = os.path.splitext(fn)
            fn = os.path.expanduser(fn)

            assert ext, "The file name in 'fn' is missing the extension."
            if ext not in ['.tex', '.csv', *excel]:
                raise ValueError(f"Unsupported extension '{ext}'; use .tex, .xlsx, .xls, or .csv.")

            # Build the tables in the saving style without calling self._output(),
            # which would also print the text summary when output='text'
            self.res_latex = self._output_latex(self.save_style)
            self.res_tibble = self._output_tibble(self.save_style)

            if not silent:
                print(f"Saving summary in {ext} ...", end="")
            if ext == '.tex':
                with open(fn, 'w', encoding='utf-8') as file:
                    file.write(self.res_latex)
            elif ext in excel:
                self.res_tibble.rename({"term": ' '}).to_excel(fn)
            else:  # '.csv'
                self.res_tibble.rename({"term": ' '}).to_csv(fn)
            if not silent:
                print('done!')

    def _save_copies(self):
        copies = self.save_copies
        if copies and self.fn:
            for ext in copies:
                print(f"Saving copy of summary in .{ext} ...", end="")
                fn, _ = os.path.splitext(self.fn)
                fn = os.path.expanduser(f"{fn}.{ext}")
                self._save(fn=fn, silent=True)
                print('done!')

    def __repr__(self):
        return ''
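The concise table assembled by `merge_models()` can be pictured with plain dictionaries: each model contributes a `{term: cell}` column, the columns are outer-joined on the term, and the parameter rows and the fit-statistic rows are joined separately and then stacked. `merge_concise`, `concise_table`, and the sample cells are hypothetical; missing cells are shown as `'--'`, as in the printed summary.

```python
def merge_concise(columns):
    """Outer-join per-model {term: cell} dicts on the term (a sketch).

    Terms keep first-seen order; missing cells are shown as '--'.
    """
    terms = []
    for column in columns.values():
        terms += [term for term in column if term not in terms]
    return [(term, [column.get(term, "--") for column in columns.values()])
            for term in terms]


def concise_table(params, fits):
    # Parameters and fit statistics are joined separately, then stacked,
    # so every fit statistic stays below every parameter row
    return merge_concise(params) + merge_concise(fits)


params = {"Model 1": {"x": "0.50*", "z": "1.20"},
          "Model 2": {"x": "0.61**", "w": "-0.30"}}
fits = {"Model 1": {"N": "100"}, "Model 2": {"N": "98"}}
for term, cells in concise_table(params, fits):
    print(f"{term:<4}", *cells)
```

Joining parameters and fit statistics in a single pass would, in an order-preserving join like this one, place a term found only in a later model (`w` above) below the fit statistics.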