# -*- mode: python -*-
# -*- coding: utf-8 -*-

##############################################################################
#
# Gestion scolarite IUT
#
# Copyright (c) 1999 - 2022 Emmanuel Viennet.  All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#   Emmanuel Viennet      emmanuel.viennet@viennet.net
#
##############################################################################

"""Fonctions de calcul des moyennes d'UE (classiques ou BUT)
"""
import numpy as np
import pandas as pd

from app import db
from app import models
from app.models import UniteEns, Module, ModuleImpl, ModuleUECoef
from app.comp import moy_mod
from app.models.formsemestre import FormSemestre
from app.scodoc import sco_codes_parcours
from app.scodoc import sco_preferences
from app.scodoc.sco_codes_parcours import UE_SPORT
from app.scodoc.sco_utils import ModuleType


def df_load_module_coefs(formation_id: int, semestre_idx: int = None) -> pd.DataFrame:
    """Charge les coefs APC des modules de la formation pour le semestre indiqué.

    En APC, ces coefs lient les modules à chaque UE.

    Résultat: (module_coefs_df, ues_no_bonus, modules)
        DataFrame rows = UEs, columns = modules, value = coef.

    Considère toutes les UE sauf bonus et tous les modules du semestre.
    Les coefs non définis (pas en base) sont mis à zéro.

    Si semestre_idx None, prend toutes les UE de la formation.
    """
    ues = (
        UniteEns.query.filter_by(formation_id=formation_id)
        .filter(UniteEns.type != sco_codes_parcours.UE_SPORT)
        .order_by(UniteEns.semestre_idx, UniteEns.numero, UniteEns.acronyme)
    )
    modules = (
        Module.query.filter_by(formation_id=formation_id)
        .filter(
            (Module.module_type == ModuleType.RESSOURCE)
            | (Module.module_type == ModuleType.SAE)
            | (
                (Module.ue_id == UniteEns.id)
                & (UniteEns.type == sco_codes_parcours.UE_SPORT)
            )
        )
        .order_by(
            Module.semestre_id, Module.module_type.desc(), Module.numero, Module.code
        )
    )
    if semestre_idx is not None:
        ues = ues.filter_by(semestre_idx=semestre_idx)
        modules = modules.filter_by(semestre_id=semestre_idx)
    ues = ues.all()
    modules = modules.all()
    ue_ids = [ue.id for ue in ues]
    module_ids = [module.id for module in modules]
    module_coefs_df = pd.DataFrame(columns=module_ids, index=ue_ids, dtype=float)
    query = (
        db.session.query(ModuleUECoef)
        .filter(UniteEns.formation_id == formation_id)
        .filter(ModuleUECoef.ue_id == UniteEns.id)
    )
    if semestre_idx is not None:
        query = query.filter(UniteEns.semestre_idx == semestre_idx)

    for mod_coef in query:
        if mod_coef.module_id in module_coefs_df:
            module_coefs_df[mod_coef.module_id][mod_coef.ue_id] = mod_coef.coef
        # silently ignore coefs associated to other modules (ie when module_type is changed)

    # Initialisation des poids non fixés:
    # 0 pour modules normaux, 1. pour bonus (car par défaut, on veut qu'un bonus agisse
    # sur toutes les UE)
    default_poids = {
        mod.id: 1.0
        if (mod.module_type == ModuleType.STANDARD) and (mod.ue.type == UE_SPORT)
        else 0.0
        for mod in modules
    }

    module_coefs_df.fillna(value=default_poids, inplace=True)

    return module_coefs_df, ues, modules


def df_load_modimpl_coefs(
    formsemestre: models.FormSemestre, ues=None, modimpls=None
) -> pd.DataFrame:
    """Charge les coefs APC des modules du formsemestre indiqué.

    Comme df_load_module_coefs mais prend seulement les UE
    et modules du formsemestre.
    Si ues et modimpls sont None, prend tous ceux du formsemestre (sauf ue bonus).
    Résultat: (module_coefs_df, ues, modules)
        DataFrame rows = UEs (sans bonus), columns = modimpl, value = coef.
    """
    if ues is None:
        ues = formsemestre.query_ues().all()
    ue_ids = [x.id for x in ues]
    if modimpls is None:
        modimpls = formsemestre.modimpls_sorted
    modimpl_ids = [x.id for x in modimpls]
    mod2impl = {m.module.id: m.id for m in modimpls}
    modimpl_coefs_df = pd.DataFrame(columns=modimpl_ids, index=ue_ids, dtype=float)
    mod_coefs = (
        db.session.query(ModuleUECoef)
        .filter(ModuleUECoef.module_id == ModuleImpl.module_id)
        .filter(ModuleImpl.formsemestre_id == formsemestre.id)
    )

    for mod_coef in mod_coefs:
        try:
            modimpl_coefs_df[mod2impl[mod_coef.module_id]][
                mod_coef.ue_id
            ] = mod_coef.coef
        except IndexError:
            # il peut y avoir en base des coefs sur des modules ou UE qui ont depuis été retirés de la formation
            pass
    # Initialisation des poids non fixés:
    # 0 pour modules normaux, 1. pour bonus (car par défaut, on veut qu'un bonus agisse
    # sur toutes les UE)
    default_poids = {
        modimpl.id: 1.0
        if (modimpl.module.module_type == ModuleType.STANDARD)
        and (modimpl.module.ue.type == UE_SPORT)
        else 0.0
        for modimpl in formsemestre.modimpls_sorted
    }

    modimpl_coefs_df.fillna(value=default_poids, inplace=True)
    return modimpl_coefs_df, ues, modimpls


def notes_sem_assemble_cube(modimpls_notes: list[pd.DataFrame]) -> np.ndarray:
    """Réuni les notes moyennes des modules du semestre en un "cube"

    modimpls_notes : liste des moyennes de module
                     (DataFrames rendus par compute_module_moy, (etud x UE))
    Resultat: ndarray (etud x module x UE)
    """
    assert len(modimpls_notes)
    modimpls_notes_arr = [df.values for df in modimpls_notes]
    modimpls_notes = np.stack(modimpls_notes_arr)
    # passe de (mod x etud x ue) à (etud x mod x ue)
    return modimpls_notes.swapaxes(0, 1)


def notes_sem_load_cube(formsemestre: FormSemestre) -> tuple:
    """Construit le "cube" (tenseur) des notes du semestre.
    Charge toutes les notes (sql), calcule les moyennes des modules
    et assemble le cube.

    etuds: tous les inscrits au semestre (avec dem. et def.)
    modimpls: _tous_ les modimpls de ce semestre (y compris bonus sport)
    UEs: toutes les UE du semestre (même si pas d'inscrits) SAUF le sport.

    Attention: la liste des modimpls inclut les modules des UE sport, mais
    elles ne sont pas dans la troisième dimension car elles n'ont pas de
    "moyenne d'UE".

    Résultat:
        sem_cube : ndarray (etuds x modimpls x UEs)
        modimpls_evals_poids dict { modimpl.id : evals_poids }
        modimpls_results dict { modimpl.id : ModuleImplResultsAPC }
    """
    modimpls_results = {}
    modimpls_evals_poids = {}
    modimpls_notes = []
    for modimpl in formsemestre.modimpls_sorted:
        mod_results = moy_mod.ModuleImplResultsAPC(modimpl)
        evals_poids, _ = moy_mod.load_evaluations_poids(modimpl.id)
        etuds_moy_module = mod_results.compute_module_moy(evals_poids)
        modimpls_results[modimpl.id] = mod_results
        modimpls_evals_poids[modimpl.id] = evals_poids
        modimpls_notes.append(etuds_moy_module)
    if len(modimpls_notes):
        cube = notes_sem_assemble_cube(modimpls_notes)
    else:
        nb_etuds = formsemestre.etuds.count()
        cube = np.zeros((nb_etuds, 0, 0), dtype=float)
    return (
        cube,
        modimpls_evals_poids,
        modimpls_results,
    )


def compute_ue_moys_apc(
    sem_cube: np.array,
    etuds: list,
    modimpls: list,
    ues: list,
    modimpl_inscr_df: pd.DataFrame,
    modimpl_coefs_df: pd.DataFrame,
    modimpl_mask: np.array,
) -> pd.DataFrame:
    """Calcul de la moyenne d'UE en mode APC (BUT).
    La moyenne d'UE est un nombre (note/20), ou NI ou NA ou ERR
        NI non inscrit à (au moins un) module de cette UE
        NA pas de notes disponibles
        ERR erreur dans une formule utilisateurs (pas gérées ici).

    sem_cube: notes moyennes aux modules
                ndarray (etuds x modimpls x UEs)
                (floats avec des NaN)
    etuds : liste des étudiants (dim. 0 du cube)
    modimpls : liste des module_impl (dim. 1 du cube)
    ues : liste des UE (dim. 2 du cube)
    modimpl_inscr_df: matrice d'inscription du semestre (etud x modimpl)
    modimpl_coefs_df: matrice coefficients (UE x modimpl), sans UEs bonus sport
    modimpl_mask: liste de booléens, indiquants le module doit être pris ou pas.
                    (utilisé pour éliminer les bonus, et pourra servir à cacluler
                    sur des sous-ensembles de modules)

    Résultat: DataFrame columns UE (sans bonus), rows etudid
    """
    nb_etuds, nb_modules, nb_ues_no_bonus = sem_cube.shape
    nb_ues_tot = len(ues)
    assert len(modimpls) == nb_modules
    if nb_modules == 0 or nb_etuds == 0 or nb_ues_no_bonus == 0:
        return pd.DataFrame(
            index=modimpl_inscr_df.index, columns=modimpl_coefs_df.index
        )
    assert len(etuds) == nb_etuds
    assert modimpl_inscr_df.shape[0] == nb_etuds
    assert modimpl_inscr_df.shape[1] == nb_modules
    assert modimpl_coefs_df.shape[0] == nb_ues_no_bonus
    assert modimpl_coefs_df.shape[1] == nb_modules
    modimpl_inscr = modimpl_inscr_df.values
    # Met à zéro tous les coefs des modules non sélectionnés dans le masque:
    modimpl_coefs = np.where(modimpl_mask, modimpl_coefs_df.values, 0.0)

    # Duplique les inscriptions sur les UEs non bonus:
    modimpl_inscr_stacked = np.stack([modimpl_inscr] * nb_ues_no_bonus, axis=2)
    # Enlève les NaN du numérateur:
    # si on veut prendre en compte les modules avec notes neutralisées ?
    sem_cube_no_nan = np.nan_to_num(sem_cube, nan=0.0)

    # Ne prend pas en compte les notes des étudiants non inscrits au module:
    # Annule les notes:
    sem_cube_inscrits = np.where(modimpl_inscr_stacked, sem_cube_no_nan, 0.0)
    # Annule les coefs des modules où l'étudiant n'est pas inscrit:
    modimpl_coefs_etuds = np.where(
        modimpl_inscr_stacked, np.stack([modimpl_coefs.T] * nb_etuds), 0.0
    )
    # Annule les coefs des modules NaN
    modimpl_coefs_etuds_no_nan = np.where(np.isnan(sem_cube), 0.0, modimpl_coefs_etuds)
    if modimpl_coefs_etuds_no_nan.dtype == np.object:  # arrive sur des tableaux vides
        modimpl_coefs_etuds_no_nan = modimpl_coefs_etuds_no_nan.astype(np.float)
    #
    # Version vectorisée
    #
    with np.errstate(invalid="ignore"):  # ignore les 0/0 (-> NaN)
        etud_moy_ue = np.sum(
            modimpl_coefs_etuds_no_nan * sem_cube_inscrits, axis=1
        ) / np.sum(modimpl_coefs_etuds_no_nan, axis=1)
    return pd.DataFrame(
        etud_moy_ue,
        index=modimpl_inscr_df.index,  # les etudids
        columns=modimpl_coefs_df.index,  # les UE sans les UE bonus sport
    )


def compute_ue_moys_classic(
    formsemestre: FormSemestre,
    sem_matrix: np.array,
    ues: list,
    modimpl_inscr_df: pd.DataFrame,
    modimpl_coefs: np.array,
    modimpl_mask: np.array,
) -> tuple[pd.Series, pd.DataFrame, pd.DataFrame]:
    """Calcul de la moyenne d'UE et de la moy. générale en mode classique (DUT, LMD, ...).

    La moyenne d'UE est un nombre (note/20), ou NI ou NA ou ERR
        NI non inscrit à (au moins un) module de cette UE
        NA pas de notes disponibles
        ERR erreur dans une formule utilisateur. [XXX pas encore gérées ici]

    L'éventuel bonus sport n'est PAS appliqué ici.

    Le masque modimpl_mask est un tableau de booléens (un par modimpl) qui
    permet de sélectionner un sous-ensemble de modules (SAEs, tout sauf sport, ...).

    sem_matrix: notes moyennes aux modules (tous les étuds x tous les modimpls)
                ndarray (etuds x modimpls)
                (floats avec des NaN)
    etuds : listes des étudiants (dim. 0 de la matrice)
    ues : liste des UE du semestre
    modimpl_inscr_df: matrice d'inscription du semestre (etud x modimpl)
    modimpl_coefs: vecteur des coefficients de modules
    modimpl_mask: masque des modimpls à prendre en compte

    Résultat:
     - moyennes générales: pd.Series, index etudid
     - moyennes d'UE: DataFrame columns UE, rows etudid
     - coefficients d'UE: DataFrame, columns UE, rows etudid
        les coefficients effectifs de chaque UE pour chaque étudiant
        (sommes de coefs de modules pris en compte)
    """
    if (not len(modimpl_mask)) or (
        sem_matrix.shape[0] == 0
    ):  # aucun module ou aucun étudiant
        # etud_moy_gen_s, etud_moy_ue_df, etud_coef_ue_df
        return (
            pd.Series(
                [0.0] * len(modimpl_inscr_df.index), index=modimpl_inscr_df.index
            ),
            pd.DataFrame(columns=[ue.id for ue in ues], index=modimpl_inscr_df.index),
            pd.DataFrame(columns=[ue.id for ue in ues], index=modimpl_inscr_df.index),
        )
    # Restreint aux modules sélectionnés:
    sem_matrix = sem_matrix[:, modimpl_mask]
    modimpl_inscr = modimpl_inscr_df.values[:, modimpl_mask]
    modimpl_coefs = modimpl_coefs[modimpl_mask]

    nb_etuds, nb_modules = sem_matrix.shape
    assert len(modimpl_coefs) == nb_modules
    nb_ues = len(ues)  # en comptant bonus

    # Enlève les NaN du numérateur:
    sem_matrix_no_nan = np.nan_to_num(sem_matrix, nan=0.0)
    # Ne prend pas en compte les notes des étudiants non inscrits au module:
    # Annule les notes:
    sem_matrix_inscrits = np.where(modimpl_inscr, sem_matrix_no_nan, 0.0)
    # Annule les coefs des modules où l'étudiant n'est pas inscrit:
    modimpl_coefs_etuds = np.where(
        modimpl_inscr, np.stack([modimpl_coefs.T] * nb_etuds), 0.0
    )
    # Annule les coefs des modules NaN (nb_etuds x nb_mods)
    modimpl_coefs_etuds_no_nan = np.where(
        np.isnan(sem_matrix), 0.0, modimpl_coefs_etuds
    )
    if modimpl_coefs_etuds_no_nan.dtype == np.object:  # arrive sur des tableaux vides
        modimpl_coefs_etuds_no_nan = modimpl_coefs_etuds_no_nan.astype(np.float)
    # ---------------------  Calcul des moyennes d'UE
    ue_modules = np.array(
        [[m.module.ue == ue for m in formsemestre.modimpls_sorted] for ue in ues]
    )[..., np.newaxis][:, modimpl_mask, :]
    modimpl_coefs_etuds_no_nan_stacked = np.stack(
        [modimpl_coefs_etuds_no_nan.T] * nb_ues
    )
    # nb_ue x nb_etuds x nb_mods : coefs prenant en compte NaN et inscriptions:
    coefs = (modimpl_coefs_etuds_no_nan_stacked * ue_modules).swapaxes(1, 2)
    if coefs.dtype == np.object:  # arrive sur des tableaux vides
        coefs = coefs.astype(np.float)
    with np.errstate(invalid="ignore"):  # ignore les 0/0 (-> NaN)
        etud_moy_ue = (
            np.sum(coefs * sem_matrix_inscrits, axis=2) / np.sum(coefs, axis=2)
        ).T
    etud_moy_ue_df = pd.DataFrame(
        etud_moy_ue, index=modimpl_inscr_df.index, columns=[ue.id for ue in ues]
    )

    #  ---------------------  Calcul des moyennes générales
    if sco_preferences.get_preference("use_ue_coefs", formsemestre.id):
        # Cas avec coefficients d'UE forcés: (on met à zéro l'UE bonus)
        etud_coef_ue_df = pd.DataFrame(
            {ue.id: ue.coefficient if ue.type != UE_SPORT else 0.0 for ue in ues},
            index=modimpl_inscr_df.index,
            columns=[ue.id for ue in ues],
        )
        # remplace NaN par zéros dans les moyennes d'UE
        etud_moy_ue_df_no_nan = etud_moy_ue_df.fillna(0.0, inplace=False)
        # Si on voulait annuler les coef d'UE dont la moyenne d'UE est NaN
        #     etud_coef_ue_df_no_nan = etud_coef_ue_df.where(etud_moy_ue_df.notna(), 0.0)
        with np.errstate(invalid="ignore"):  # ignore les 0/0 (-> NaN)
            etud_moy_gen_s = (etud_coef_ue_df * etud_moy_ue_df_no_nan).sum(
                axis=1
            ) / etud_coef_ue_df.sum(axis=1)
    else:
        # Cas normal: pondère directement les modules
        etud_coef_ue_df = pd.DataFrame(
            coefs.sum(axis=2).T,
            index=modimpl_inscr_df.index,  # etudids
            columns=[ue.id for ue in ues],
        )
        with np.errstate(invalid="ignore"):  # ignore les 0/0 (-> NaN)
            etud_moy_gen = np.sum(
                modimpl_coefs_etuds_no_nan * sem_matrix_inscrits, axis=1
            ) / np.sum(modimpl_coefs_etuds_no_nan, axis=1)

            etud_moy_gen_s = pd.Series(etud_moy_gen, index=modimpl_inscr_df.index)

    return etud_moy_gen_s, etud_moy_ue_df, etud_coef_ue_df


def compute_mat_moys_classic(
    sem_matrix: np.array,
    modimpl_inscr_df: pd.DataFrame,
    modimpl_coefs: np.array,
    modimpl_mask: np.array,
) -> pd.Series:
    """Calcul de la moyenne sur un sous-enemble de modules en formation CLASSIQUE

    La moyenne est un nombre (note/20 ou NaN.

    Le masque modimpl_mask est un tableau de booléens (un par modimpl) qui
    permet de sélectionner un sous-ensemble de modules (ceux de la matière d'intérêt).

    sem_matrix: notes moyennes aux modules (tous les étuds x tous les modimpls)
                ndarray (etuds x modimpls)
                (floats avec des NaN)
    etuds : listes des étudiants (dim. 0 de la matrice)
    modimpl_inscr_df: matrice d'inscription du semestre (etud x modimpl)
    modimpl_coefs: vecteur des coefficients de modules
    modimpl_mask: masque des modimpls à prendre en compte

    Résultat:
     - moyennes: pd.Series, index etudid
    """
    if (not len(modimpl_mask)) or (
        sem_matrix.shape[0] == 0
    ):  # aucun module ou aucun étudiant
        # etud_moy_gen_s, etud_moy_ue_df, etud_coef_ue_df
        return pd.Series(
            [0.0] * len(modimpl_inscr_df.index), index=modimpl_inscr_df.index
        )
    # Restreint aux modules sélectionnés:
    sem_matrix = sem_matrix[:, modimpl_mask]
    modimpl_inscr = modimpl_inscr_df.values[:, modimpl_mask]
    modimpl_coefs = modimpl_coefs[modimpl_mask]

    nb_etuds, nb_modules = sem_matrix.shape
    assert len(modimpl_coefs) == nb_modules

    # Enlève les NaN du numérateur:
    sem_matrix_no_nan = np.nan_to_num(sem_matrix, nan=0.0)
    # Ne prend pas en compte les notes des étudiants non inscrits au module:
    # Annule les notes:
    sem_matrix_inscrits = np.where(modimpl_inscr, sem_matrix_no_nan, 0.0)
    # Annule les coefs des modules où l'étudiant n'est pas inscrit:
    modimpl_coefs_etuds = np.where(
        modimpl_inscr, np.stack([modimpl_coefs.T] * nb_etuds), 0.0
    )
    # Annule les coefs des modules NaN (nb_etuds x nb_mods)
    modimpl_coefs_etuds_no_nan = np.where(
        np.isnan(sem_matrix), 0.0, modimpl_coefs_etuds
    )
    if modimpl_coefs_etuds_no_nan.dtype == np.object:  # arrive sur des tableaux vides
        modimpl_coefs_etuds_no_nan = modimpl_coefs_etuds_no_nan.astype(np.float)

    etud_moy_mat = (modimpl_coefs_etuds_no_nan * sem_matrix_inscrits).sum(
        axis=1
    ) / modimpl_coefs_etuds_no_nan.sum(axis=1)

    return pd.Series(etud_moy_mat, index=modimpl_inscr_df.index)


def compute_malus(
    formsemestre: FormSemestre,
    sem_modimpl_moys: np.array,
    ues: list[UniteEns],
    modimpl_inscr_df: pd.DataFrame,
) -> pd.DataFrame:
    """Calcul le malus sur les UE
    Dans chaque UE, on peut avoir un  ou plusieurs modules de MALUS.
    Leurs notes sont positives ou négatives.
    La somme des notes de malus somme est _soustraite_ à la moyenne de chaque UE.

    Arguments:
        - sem_modimpl_moys :
            notes moyennes aux modules (tous les étuds x tous les modimpls)
            floats avec des NaN.
            En classique: sem_matrix, ndarray (etuds x modimpls)
            En APC: sem_cube, ndarray (etuds x modimpls x UEs non bonus)
        - ues: les ues du semestre (incluant le bonus sport)
        - modimpl_inscr_df: matrice d'inscription aux modules du semestre (etud x modimpl)

    Résultat: DataFrame de float, index etudid, columns: ue.id (sans NaN)
    """
    ues_idx = [ue.id for ue in ues]
    malus = pd.DataFrame(index=modimpl_inscr_df.index, columns=ues_idx, dtype=float)
    for ue in ues:
        if ue.type != UE_SPORT:
            modimpl_mask = np.array(
                [
                    (m.module.module_type == ModuleType.MALUS)
                    and (m.module.ue.id == ue.id)
                    for m in formsemestre.modimpls_sorted
                ]
            )
            if len(modimpl_mask):
                malus_moys = sem_modimpl_moys[:, modimpl_mask].sum(axis=1)
                malus[ue.id] = malus_moys

    malus.fillna(0.0, inplace=True)
    return malus