#!/usr/bin/env python
# -*- coding: UTF-8 -*


"""Outil pour migration ScoDoc 7 => ScoDoc 8

 - Liste des appels de la forme context.xxx
./refactor.py showcontextcalls app/scodoc/*.py

 - remplace context.xxx par module.xxx

./refactor.py refactor module.method app/scodoc/*.py



Pour chaque module dans views:
    - construire la liste des fonctions définies dans ce module:
      get_module_symbols

Pour chaque module dans views et dans scodoc:
    - remplacer context.xxx par app.views.M.xxx
    où M est le module de views définissant xxx
    Si xxx n'est pas trouvé, erreur !
"""


from __future__ import print_function
import re
from pprint import pprint as pp
import os
import sys
import types
import tempfile
import shutil
import click

# Kinds of module attributes collected by get_module_symbols().
# Only plain functions are scanned for now; other kinds (classes, dicts,
# floats, ints, lists, strings, tuples) are deliberately left out.
TYPES_TO_SCAN = {types.FunctionType}


def get_module_symbols(module):
    """Return the ``__name__`` of each scanned object found in ``module``.

    An attribute of ``module`` is kept when its exact type belongs to
    ``TYPES_TO_SCAN``. Functions must additionally have been defined in
    ``module`` itself (their ``__module__`` equals ``module.__name__``), so
    functions merely imported into the module are skipped.
    """
    found = []
    for attr_name in dir(module):
        obj = getattr(module, attr_name)
        if type(obj) not in TYPES_TO_SCAN:
            continue
        # Functions coming from another module are not "defined here".
        if type(obj) == types.FunctionType and obj.__module__ != module.__name__:
            continue
        found.append(obj.__name__)
    return found


# print("\n".join(f.__name__ for f in get_module_functions(notes)))


def scan_views_symbols():
    """Map each symbol defined in the ``app.views`` modules to its module.

    Returns a dict ``{ symbol_name : module_name }`` where ``module_name`` is
    the module's dotted name stripped of the leading ``"app.views."``.
    The number of symbols per module is printed, as well as every symbol
    already registered by a previously scanned module; in that case the
    module scanned later overwrites the earlier entry.
    """
    import app

    prefix = "app.views."

    # Keep only the attributes of app.views that are modules.
    views_modules = []
    for attr_name in dir(app.views):
        candidate = getattr(app.views, attr_name)
        if type(candidate) == types.ModuleType:
            views_modules.append(candidate)

    sym2mod = {}  # symbol name -> short module name
    for module in views_modules:
        assert module.__name__.startswith(prefix)
        short_name = module.__name__[len(prefix) :]
        symbols = set(get_module_symbols(module))
        print("%d symbols defined in %s" % (len(symbols), module))

        dups = symbols.intersection(sym2mod)
        if dups:
            print("duplicated symbols !")
            for dup in dups:
                print("%s:\t%s\t%s" % (dup, sym2mod[dup], short_name))

        for symbol in symbols:
            sym2mod[symbol] = short_name
    return sym2mod


def replace_context_calls(sourcefilename, sym2mod):
    """Rewrite ``context.xxx`` references of a source file as ``module.xxx``.

    The file on disk is only read, never modified: the rewritten text is
    returned to the caller.

    Args:
        sourcefilename: path of the source file to read.
        sym2mod: dict mapping a symbol name to the name of the module that
            defines it.

    Returns:
        A tuple ``(new_source, undefined_list)``. ``undefined_list`` holds one
        ``(sourcefilename, name)`` pair per occurrence of ``context.name``
        for which ``sym2mod`` has no (non-empty) entry; such occurrences are
        left unchanged in ``new_source``.
    """
    undefined_list = []  # names not found in the "views" modules

    def repl(m):
        funcname = m.group(1)
        module = sym2mod.get(funcname)
        if not module:
            undefined_list.append((sourcefilename, funcname))
            return m.group(0)  # leave unchanged
        return module + "." + funcname

    print("reading %s" % sourcefilename)
    # Use a context manager so the file handle is closed deterministically.
    with open(sourcefilename) as f:
        source = f.read()
    exp = re.compile(r"context\.([a-zA-Z0-9_]+)")
    source2 = exp.sub(repl, source)
    return source2, undefined_list


# sym2mod = scan_views_symbols()

# source2, undefined_list = replace_context_calls("app/scodoc/sco_core.py", sym2mod)


def list_context_calls(sourcefilename):
    """Return the names accessed on ``context`` in the given source file.

    Every ``context.xxx`` occurrence in the file contributes ``xxx``; the
    result is a sorted list without duplicates.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(sourcefilename) as f:
        source = f.read()
    exp = re.compile(r"context\.([a-zA-Z0-9_]+)")
    return sorted(set(exp.findall(source)))


def get_context_calls(src_filenames):
    """Index the ``context`` methods used across the given source files.

    Returns ``{ method_name : [ module names ] }`` where a module name is the
    file's base name without its extension, in the order the files were given.
    """
    calls = {}
    for path in src_filenames:
        methods = list_context_calls(path)
        module_name = os.path.splitext(os.path.basename(path))[0]
        for method in methods:
            calls.setdefault(method, []).append(module_name)
    return calls


# Root click command group. It does nothing by itself: the sub-commands of
# this script attach to it through the @cli.command() decorator.
@click.group()
def cli():
    pass


@cli.command()
@click.argument("src_filenames", nargs=-1)
def showcontextcalls(src_filenames):
    # Sub-command: print one line per method accessed on `context`, in
    # alphabetical order, followed by the modules in which it is used.
    click.echo("Appels de méthodes sur l'object context")
    calls = get_context_calls(src_filenames)
    for method, modules in sorted(calls.items()):
        print(method + ":\t" + ", ".join(modules))


@cli.command()
@click.argument("modulemethod", nargs=1)
@click.argument("src_filenames", nargs=-1)
def refactor(modulemethod, src_filenames):
    """Replace call context.method(...)
    by module.method(context, ...)
    in all given source filenames
    """
    # modulemethod has the form "module.method"; everything before the last
    # dot is the module name, which may itself contain dots.
    modulemethod = str(modulemethod)  # avoid unicode in Python2
    frags = modulemethod.split(".")
    if len(frags) < 2:
        raise click.BadParameter("must be module.method", param_hint="modulemethod")
    module = ".".join(frags[:-1])
    method = frags[-1]
    old_call = "context." + method + "("
    # Original versions of the modified files are moved into this directory.
    backup = tempfile.mkdtemp(dir="/tmp")
    for sourcefilename in src_filenames:
        source_module_name = os.path.splitext(os.path.split(sourcefilename)[1])[0]
        # Context managers guarantee the handles are closed (and the rewritten
        # file flushed) before moving on to the next file.
        with open(sourcefilename) as f:
            source = f.read()
        if source_module_name == module:
            # call in the same module: no module prefix needed
            source2 = source.replace(old_call, method + "(context, ")
        else:
            new_call = module + "." + method + "(context, "
            source2 = source.replace(old_call, new_call)
            # also handle calls made through context.Notes
            source2 = source2.replace("context.Notes." + method + "(", new_call)
        if source2 != source:
            print("changed %s" % sourcefilename)
            # Keep the untouched original in the backup directory, then
            # write the new content under the original path.
            shutil.move(sourcefilename, backup)
            with open(sourcefilename, "w") as f:
                f.write(source2)
    print("Done.\noriginal files saved in %s\n" % backup)


# Script entry point: run the click command group.
if __name__ == "__main__":
    try:
        cli(obj={})
    except SystemExit as e:
        # A SystemExit carrying code 0 is swallowed so the script simply ends;
        # any other exit code is propagated unchanged.
        if e.code != 0:
            raise