# -*- mode: python -*-
# -*- coding: utf-8 -*-

##############################################################################
#
# Gestion scolarite IUT
#
# Copyright (c) 1999 - 2021 Emmanuel Viennet.  All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#   Emmanuel Viennet      emmanuel.viennet@viennet.net
#
##############################################################################

"""Fonctions de calcul des moyennes d'UE
"""
import numpy as np
import pandas as pd

from app import db
from app import models
from app.models import UniteEns, Module, ModuleImpl, ModuleUECoef
from app.comp import moy_mod
from app.models.formsemestre import FormSemestre
from app.scodoc import sco_codes_parcours


def df_load_module_coefs(formation_id: int, semestre_idx: int = None) -> pd.DataFrame:
    """Charge les coefs des modules de la formation pour le semestre indiqué.

    Ces coefs lient les modules à chaque UE.

    Résultat: (module_coefs_df, ues, modules)
        DataFrame rows = UEs, columns = modules, value = coef.

    Considère toutes les UE (sauf sport) et modules du semestre.
    Les coefs non définis (pas en base) sont mis à zéro.

    Si semestre_idx None, prend toutes les UE de la formation.
    """
    ues = UniteEns.query.filter_by(formation_id=formation_id).filter(
        UniteEns.type != sco_codes_parcours.UE_SPORT
    )
    modules = Module.query.filter_by(formation_id=formation_id).order_by(
        Module.semestre_id, Module.numero
    )
    if semestre_idx is not None:
        ues = ues.filter_by(semestre_idx=semestre_idx)
        modules = modules.filter_by(semestre_id=semestre_idx)
    ues = ues.all()
    modules = modules.all()
    ue_ids = [ue.id for ue in ues]
    module_ids = [module.id for module in modules]
    module_coefs_df = pd.DataFrame(columns=module_ids, index=ue_ids, dtype=float)
    for mod_coef in (
        db.session.query(ModuleUECoef)
        .filter(UniteEns.formation_id == formation_id)
        .filter(ModuleUECoef.ue_id == UniteEns.id)
    ):
        module_coefs_df[mod_coef.module_id][mod_coef.ue_id] = mod_coef.coef
    module_coefs_df.fillna(value=0, inplace=True)
    return module_coefs_df, ues, modules


def df_load_modimpl_coefs(
    formsemestre: models.FormSemestre, ues=None, modimpls=None
) -> pd.DataFrame:
    """Charge les coefs des modules du formsemestre indiqué.

    Comme df_load_module_coefs mais prend seulement les UE
    et modules du formsemestre.
    Si ues et modimpls sont None, prend tous ceux du formsemestre.
    Résultat: (module_coefs_df, ues, modules)
        DataFrame rows = UEs, columns = modimpl, value = coef.
    """
    if ues is None:
        ues = formsemestre.query_ues().all()
    ue_ids = [x.id for x in ues]
    if modimpls is None:
        modimpls = formsemestre.modimpls.all()
    modimpl_ids = [x.id for x in modimpls]
    mod2impl = {m.module.id: m.id for m in modimpls}
    modimpl_coefs_df = pd.DataFrame(columns=modimpl_ids, index=ue_ids, dtype=float)
    mod_coefs = (
        db.session.query(ModuleUECoef)
        .filter(ModuleUECoef.module_id == ModuleImpl.module_id)
        .filter(ModuleImpl.formsemestre_id == formsemestre.id)
    )

    for mod_coef in mod_coefs:
        modimpl_coefs_df[mod2impl[mod_coef.module_id]][mod_coef.ue_id] = mod_coef.coef
    modimpl_coefs_df.fillna(value=0, inplace=True)
    return modimpl_coefs_df, ues, modimpls


def notes_sem_assemble_cube(modimpls_notes: list[pd.DataFrame]) -> np.ndarray:
    """Réuni les notes moyennes des modules du semestre en un "cube"

    modimpls_notes : liste des moyennes de module
                     (DataFrames rendus par compute_module_moy, (etud x UE))
    Resultat: ndarray (etud x module x UE)
    """
    modimpls_notes_arr = [df.values for df in modimpls_notes]
    modimpls_notes = np.stack(modimpls_notes_arr)
    # passe de (mod x etud x ue) à (etud x mod x UE)
    return modimpls_notes.swapaxes(0, 1)


def notes_sem_load_cube(formsemestre):
    """Calcule le cube des notes du semestre
    (charge toutes les notes, calcule les moyenne des modules
    et assemble le cube)
    Resultat: ndarray (etuds x modimpls x UEs)
    """
    modimpls_evals_poids = {}  # modimpl.id : evals_poids
    modimpls_evals_notes = {}  # modimpl.id : evals_notes
    modimpls_evaluations = {}  # modimpl.id : liste des évaluations
    modimpls_notes = []
    for modimpl in formsemestre.modimpls:
        evals_notes, evaluations = moy_mod.df_load_modimpl_notes(modimpl.id)
        evals_poids, ues = moy_mod.df_load_evaluations_poids(modimpl.id)
        etuds_moy_module = moy_mod.compute_module_moy(
            evals_notes, evals_poids, evaluations
        )
        modimpls_evals_poids[modimpl.id] = evals_poids
        modimpls_evals_notes[modimpl.id] = evals_notes
        modimpls_evaluations[modimpl.id] = evaluations
        modimpls_notes.append(etuds_moy_module)
    return (
        notes_sem_assemble_cube(modimpls_notes),
        modimpls_evals_poids,
        modimpls_evals_notes,
        modimpls_evaluations,
    )


def compute_ue_moys(
    sem_cube: np.array,
    etuds: list,
    modimpls: list,
    ues: list,
    modimpl_inscr_df: pd.DataFrame,
    modimpl_coefs_df: pd.DataFrame,
) -> pd.DataFrame:
    """Calcul de la moyenne d'UE
    La moyenne d'UE est un nombre (note/20), ou NI ou NA ou ERR
    NI non inscrit à (au moins un) module de cette UE
    NA pas de notes disponibles
    ERR erreur dans une formule utilisateur. [XXX pas encore gérées ici]

    sem_cube: notes moyennes aux modules
                ndarray (etuds x modimpls x UEs)
                (floats avec des NaN)
    etuds : lites des étudiants (dim. 0 du cube)
    modimpls : liste des modules à considérer (dim. 1 du cube)
    ues : liste des UE (dim. 2 du cube)
    module_inscr_df: matrice d'inscription du semestre (etud x modimpl)
    module_coefs_df: matrice coefficients (UE x modimpl)

    Resultat: DataFrame columns UE, rows etudid
    """
    nb_etuds, nb_modules, nb_ues = sem_cube.shape
    assert len(etuds) == nb_etuds
    assert len(modimpls) == nb_modules
    assert len(ues) == nb_ues
    assert modimpl_inscr_df.shape[0] == nb_etuds
    assert modimpl_inscr_df.shape[1] == nb_modules
    assert modimpl_coefs_df.shape[0] == nb_ues
    assert modimpl_coefs_df.shape[1] == nb_modules
    modimpl_inscr = modimpl_inscr_df.values
    modimpl_coefs = modimpl_coefs_df.values
    # Duplique les inscriptions sur les UEs:
    modimpl_inscr_stacked = np.stack([modimpl_inscr] * nb_ues, axis=2)

    # Enlève les NaN du numérateur:
    # si on veut prendre en compte les module avec notes neutralisées ?
    # sem_cube_no_nan = np.nan_to_num(sem_cube, nan=0.0)

    # Ne prend pas en compte les notes des étudiants non inscrits au module:
    # Annule les notes:
    sem_cube_inscrits = np.where(modimpl_inscr_stacked, sem_cube, 0.0)
    # Annule les coefs des modules où l'étudiant n'est pas inscrit:
    modimpl_coefs_etuds = np.where(
        modimpl_inscr_stacked, np.stack([modimpl_coefs.T] * nb_etuds), 0.0
    )

    #
    # Version vectorisée
    #
    etud_moy_ue = np.sum(modimpl_coefs_etuds * sem_cube_inscrits, axis=1) / np.sum(
        modimpl_coefs_etuds, axis=1
    )
    return pd.DataFrame(
        etud_moy_ue, index=modimpl_inscr_df.index, columns=modimpl_coefs_df.index
    )