"""
Helpers for writing plotting scripts quickly,
mostly to simplify saving figures from the command line.
"""

import os
import os.path as path
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import numpy as np

def ArgumentParserWithFigure(*args, **kwargs):
    """
    Create an ArgumentParser with an optional positional ``fname`` argument.

    All positional and keyword arguments are forwarded unchanged to
    ``argparse.ArgumentParser``. Because ``fname`` uses ``nargs="?"``,
    it parses to None when omitted on the command line.
    """
    fig_parser = ArgumentParser(*args, **kwargs)
    help_text = ("Location for generated figure, will append __file__ if a directory. "
                 "If not supplied, figure is shown.")
    fig_parser.add_argument(
        "fname",
        nargs="?",
        metavar="path/to/figure[/]",
        help=help_text,
    )
    return fig_parser

def save_all_figs_to_path_or_show(fnames, figs=None, default_basename=None,
                                  default_extensions=('.pdf', '.png')):
    """
    Save figures to the given file names, or show them if no names are given.

    Parameters
    ----------
    fnames : str, os.PathLike, sequence, or falsy
        Where to save.
        - Falsy (None, "", [], False): call ``plt.show()`` instead of saving.
        - A single name: used for every figure. With more than one figure,
          ``_figNN`` (the figure number) is inserted before the extension.
        - A sequence with one entry per figure: entries are paired with figures
          in order. An entry may itself be a list of names for that figure.
        - Any other sequence: every figure is saved under every entry, with
          ``_figNN`` inserted when there is more than one figure.
    figs : sequence of Figure, optional
        Figures to save. Defaults to all open pyplot figures.
    default_basename : str or os.PathLike, optional
        When a target name is an existing directory, the file is written into
        it using this name, with the directory part and extension dropped
        (e.g. pass ``__file__``). If None, a ``basefilename`` attribute on the
        figure is used instead. If neither is available, the directory path
        itself is used as the file stem.
    default_extensions : iterable of str
        Extensions used when a target has no extension. One file is written
        per extension.

    Every figure is saved with ``transparent=True``.
    """
    if not fnames:
        # empty list, "", False, None: nothing to save, so display instead
        plt.show()
        return

    if figs is None:
        figs = [plt.figure(i) for i in plt.get_fignums()]
    else:
        figs = list(figs)

    if default_basename is not None:
        # keep only the file name, e.g. when __file__ is passed
        default_basename = path.basename(default_basename)

    # a single name rather than a collection of names
    if isinstance(fnames, (str, os.PathLike)):
        fnames = [fnames]
    else:
        fnames = list(fnames)

    # zero-pad figure numbers so generated names sort naturally;
    # len(str(n)) is the digit count of n and stays valid when n == 0
    pad_width = max(2, len(str(len(figs))))

    for fig, entry, append_num in _pair_figs_with_fnames(figs, fnames):
        for fname in _expand_fnames(fig, entry, append_num, default_basename,
                                    default_extensions, pad_width):
            fig.savefig(fname, transparent=True)


def _pair_figs_with_fnames(figs, fnames):
    """Return (fig, name_entry, append_num) triples describing what to save."""
    if len(fnames) == len(figs):
        # one entry per figure, matched by position
        return [(fig, entry, False) for fig, entry in zip(figs, fnames)]
    # every figure under every entry (a single entry is the common case);
    # number the files whenever several figures share one entry, so they
    # do not overwrite each other
    multiple = len(figs) > 1
    return [(fig, entry, multiple) for entry in fnames for fig in figs]


def _expand_fnames(fig, entry, append_num, default_basename,
                   default_extensions, pad_width):
    """Turn one name entry for ``fig`` into the concrete list of paths to write."""
    if isinstance(entry, str) or not hasattr(entry, '__len__'):
        # single name (str or path-like) rather than a list of names
        entry = [entry]

    out = []
    for fname in entry:
        fname = os.fspath(fname)
        if path.isdir(fname):
            if default_basename is not None:
                fname = path.join(fname, path.splitext(default_basename)[0])
            elif hasattr(fig, 'basefilename'):
                fname = path.join(fname, path.splitext(fig.basefilename)[0])

        root, ext = path.splitext(fname)
        if append_num:
            # insert before the extension so savefig can still infer the format
            root += "_fig{:0{width}d}".format(fig.number, width=pad_width)

        if ext:
            out.append(root + ext)
        else:
            out.extend(root + e for e in default_extensions)
    return out