# -*- mode: python -*-
# -*- coding: utf-8 -*-

##############################################################################
#
# Gestion scolarite IUT
#
# Copyright (c) 1999 - 2022 Emmanuel Viennet.  All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#   Emmanuel Viennet      emmanuel.viennet@viennet.net
#
##############################################################################

"""Fonctions de calcul des moyennes d'UE (classiques ou BUT)
"""
from re import X
import numpy as np
import pandas as pd

from app import db
from app import models
from app.models import UniteEns, Module, ModuleImpl, ModuleUECoef
from app.comp import moy_mod
from app.models.formsemestre import FormSemestre
from app.scodoc import sco_codes_parcours
from app.scodoc import sco_preferences
from app.scodoc.sco_codes_parcours import UE_SPORT
from app.scodoc.sco_utils import ModuleType


def df_load_module_coefs(formation_id: int, semestre_idx: int = None) -> pd.DataFrame:
    """Charge les coefs APC des modules de la formation pour le semestre indiqué.

    En APC, ces coefs lient les modules à chaque UE.

    Résultat: (module_coefs_df, ues_no_bonus, modules)
        DataFrame rows = UEs, columns = modules, value = coef.

    Considère toutes les UE sauf bonus et tous les modules du semestre.
    Les coefs non définis (pas en base) sont mis à zéro.

    Si semestre_idx None, prend toutes les UE de la formation.
    """
    ues = (
        UniteEns.query.filter_by(formation_id=formation_id)
        .filter(UniteEns.type != sco_codes_parcours.UE_SPORT)
        .order_by(UniteEns.semestre_idx, UniteEns.numero, UniteEns.acronyme)
    )
    modules = (
        Module.query.filter_by(formation_id=formation_id)
        .filter(
            (Module.module_type == ModuleType.RESSOURCE)
            | (Module.module_type == ModuleType.SAE)
            | (
                (Module.ue_id == UniteEns.id)
                & (UniteEns.type == sco_codes_parcours.UE_SPORT)
            )
        )
        .order_by(
            Module.semestre_id, Module.module_type.desc(), Module.numero, Module.code
        )
    )
    if semestre_idx is not None:
        ues = ues.filter_by(semestre_idx=semestre_idx)
        modules = modules.filter_by(semestre_id=semestre_idx)
    ues = ues.all()
    modules = modules.all()
    ue_ids = [ue.id for ue in ues]
    module_ids = [module.id for module in modules]
    module_coefs_df = pd.DataFrame(columns=module_ids, index=ue_ids, dtype=float)
    query = (
        db.session.query(ModuleUECoef)
        .filter(UniteEns.formation_id == formation_id)
        .filter(ModuleUECoef.ue_id == UniteEns.id)
    )
    if semestre_idx is not None:
        query = query.filter(UniteEns.semestre_idx == semestre_idx)

    for mod_coef in query:
        if mod_coef.module_id in module_coefs_df:
            module_coefs_df[mod_coef.module_id][mod_coef.ue_id] = mod_coef.coef
        # silently ignore coefs associated to other modules (ie when module_type is changed)

    # Initialisation des poids non fixés:
    # 0 pour modules normaux, 1. pour bonus (car par défaut, on veut qu'un bonus agisse
    # sur toutes les UE)
    default_poids = {
        mod.id: 1.0
        if (mod.module_type == ModuleType.STANDARD) and (mod.ue.type == UE_SPORT)
        else 0.0
        for mod in modules
    }

    module_coefs_df.fillna(value=default_poids, inplace=True)

    return module_coefs_df, ues, modules


def df_load_modimpl_coefs(
    formsemestre: models.FormSemestre, ues=None, modimpls=None
) -> pd.DataFrame:
    """Charge les coefs APC des modules du formsemestre indiqué.

    Comme df_load_module_coefs mais prend seulement les UE
    et modules du formsemestre.
    Si ues et modimpls sont None, prend tous ceux du formsemestre (sauf ue bonus).
    Résultat: (module_coefs_df, ues, modules)
        DataFrame rows = UEs (sans bonus), columns = modimpl, value = coef.
    """
    if ues is None:
        ues = formsemestre.query_ues().all()
    ue_ids = [x.id for x in ues]
    if modimpls is None:
        modimpls = formsemestre.modimpls_sorted
    modimpl_ids = [x.id for x in modimpls]
    mod2impl = {m.module.id: m.id for m in modimpls}
    modimpl_coefs_df = pd.DataFrame(columns=modimpl_ids, index=ue_ids, dtype=float)
    mod_coefs = (
        db.session.query(ModuleUECoef)
        .filter(ModuleUECoef.module_id == ModuleImpl.module_id)
        .filter(ModuleImpl.formsemestre_id == formsemestre.id)
    )

    for mod_coef in mod_coefs:
        modimpl_coefs_df[mod2impl[mod_coef.module_id]][mod_coef.ue_id] = mod_coef.coef

    # Initialisation des poids non fixés:
    # 0 pour modules normaux, 1. pour bonus (car par défaut, on veut qu'un bonus agisse
    # sur toutes les UE)
    default_poids = {
        modimpl.id: 1.0
        if (modimpl.module.module_type == ModuleType.STANDARD)
        and (modimpl.module.ue.type == UE_SPORT)
        else 0.0
        for modimpl in formsemestre.modimpls_sorted
    }

    modimpl_coefs_df.fillna(value=default_poids, inplace=True)
    return modimpl_coefs_df, ues, modimpls


def notes_sem_assemble_cube(modimpls_notes: list[pd.DataFrame]) -> np.ndarray:
    """Réuni les notes moyennes des modules du semestre en un "cube"

    modimpls_notes : liste des moyennes de module
                     (DataFrames rendus par compute_module_moy, (etud x UE))
    Resultat: ndarray (etud x module x UE)
    """
    assert len(modimpls_notes)
    modimpls_notes_arr = [df.values for df in modimpls_notes]
    modimpls_notes = np.stack(modimpls_notes_arr)
    # passe de (mod x etud x ue) à (etud x mod x ue)
    return modimpls_notes.swapaxes(0, 1)


def notes_sem_load_cube(formsemestre: FormSemestre) -> tuple:
    """Construit le "cube" (tenseur) des notes du semestre.
    Charge toutes les notes (sql), calcule les moyennes des modules
    et assemble le cube.

    etuds: tous les inscrits au semestre (avec dem. et def.)
    modimpls: _tous_ les modimpls de ce semestre (y compris bonus sport)
    UEs: toutes les UE du semestre (même si pas d'inscrits) SAUF le sport.

    Attention: la liste des modimpls inclut les modules des UE sport, mais
    elles ne sont pas dans la troisième dimension car elles n'ont pas de
    "moyenne d'UE".

    Résultat:
        sem_cube : ndarray (etuds x modimpls x UEs)
        modimpls_evals_poids dict { modimpl.id : evals_poids }
        modimpls_results dict { modimpl.id : ModuleImplResultsAPC }
    """
    modimpls_results = {}
    modimpls_evals_poids = {}
    modimpls_notes = []
    for modimpl in formsemestre.modimpls_sorted:
        mod_results = moy_mod.ModuleImplResultsAPC(modimpl)
        evals_poids, _ = moy_mod.load_evaluations_poids(modimpl.id)
        etuds_moy_module = mod_results.compute_module_moy(evals_poids)
        modimpls_results[modimpl.id] = mod_results
        modimpls_notes.append(etuds_moy_module)
    if len(modimpls_notes):
        cube = notes_sem_assemble_cube(modimpls_notes)
    else:
        nb_etuds = formsemestre.etuds.count()
        cube = np.zeros((nb_etuds, 0, 0), dtype=float)
    return (
        cube,
        modimpls_evals_poids,
        modimpls_results,
    )


def compute_ue_moys_apc(
    sem_cube: np.array,
    etuds: list,
    modimpls: list,
    ues: list,
    modimpl_inscr_df: pd.DataFrame,
    modimpl_coefs_df: pd.DataFrame,
) -> pd.DataFrame:
    """Calcul de la moyenne d'UE en mode APC (BUT).
    La moyenne d'UE est un nombre (note/20), ou NI ou NA ou ERR
        NI non inscrit à (au moins un) module de cette UE
        NA pas de notes disponibles
        ERR erreur dans une formule utilisateur. [XXX pas encore gérées ici]

    sem_cube: notes moyennes aux modules
                ndarray (etuds x modimpls x UEs)
                (floats avec des NaN)
    etuds : liste des étudiants (dim. 0 du cube)
    modimpls : liste des modules à considérer (dim. 1 du cube)
    ues : liste des UE (dim. 2 du cube)
    modimpl_inscr_df: matrice d'inscription du semestre (etud x modimpl)
    modimpl_coefs_df: matrice coefficients (UE x modimpl), sans UEs bonus sport

    Résultat: DataFrame columns UE (sans sport), rows etudid
    """
    nb_etuds, nb_modules, nb_ues_no_bonus = sem_cube.shape
    nb_ues_tot = len(ues)
    assert len(modimpls) == nb_modules
    if nb_modules == 0 or nb_etuds == 0:
        return pd.DataFrame(
            index=modimpl_inscr_df.index, columns=modimpl_coefs_df.index
        )
    assert len(etuds) == nb_etuds
    assert modimpl_inscr_df.shape[0] == nb_etuds
    assert modimpl_inscr_df.shape[1] == nb_modules
    assert modimpl_coefs_df.shape[0] == nb_ues_no_bonus
    assert modimpl_coefs_df.shape[1] == nb_modules
    modimpl_inscr = modimpl_inscr_df.values
    modimpl_coefs = modimpl_coefs_df.values

    # Duplique les inscriptions sur les UEs non bonus:
    modimpl_inscr_stacked = np.stack([modimpl_inscr] * nb_ues_no_bonus, axis=2)
    # Enlève les NaN du numérateur:
    # si on veut prendre en compte les modules avec notes neutralisées ?
    sem_cube_no_nan = np.nan_to_num(sem_cube, nan=0.0)

    # Ne prend pas en compte les notes des étudiants non inscrits au module:
    # Annule les notes:
    sem_cube_inscrits = np.where(modimpl_inscr_stacked, sem_cube_no_nan, 0.0)
    # Annule les coefs des modules où l'étudiant n'est pas inscrit:
    modimpl_coefs_etuds = np.where(
        modimpl_inscr_stacked, np.stack([modimpl_coefs.T] * nb_etuds), 0.0
    )
    # Annule les coefs des modules NaN
    modimpl_coefs_etuds_no_nan = np.where(np.isnan(sem_cube), 0.0, modimpl_coefs_etuds)
    #
    # Version vectorisée
    #
    with np.errstate(invalid="ignore"):  # ignore les 0/0 (-> NaN)
        etud_moy_ue = np.sum(
            modimpl_coefs_etuds_no_nan * sem_cube_inscrits, axis=1
        ) / np.sum(modimpl_coefs_etuds_no_nan, axis=1)
    return pd.DataFrame(
        etud_moy_ue,
        index=modimpl_inscr_df.index,  # les etudids
        columns=modimpl_coefs_df.index,  # les UE sans les UE bonus sport
    )


def compute_ue_moys_classic(
    formsemestre: FormSemestre,
    sem_matrix: np.array,
    ues: list,
    modimpl_inscr_df: pd.DataFrame,
    modimpl_coefs: np.array,
    modimpl_mask: np.array,
) -> tuple[pd.Series, pd.DataFrame, pd.DataFrame]:
    """Calcul de la moyenne d'UE en mode classique.
    La moyenne d'UE est un nombre (note/20), ou NI ou NA ou ERR
        NI non inscrit à (au moins un) module de cette UE
        NA pas de notes disponibles
        ERR erreur dans une formule utilisateur. [XXX pas encore gérées ici]

    L'éventuel bonus sport n'est PAS appliqué ici.

    Le masque modimpl_mask est un tableau de booléens (un par modimpl) qui
    permet de sélectionner un sous-ensemble de modules (SAEs, tout sauf sport, ...).

    sem_matrix: notes moyennes aux modules (tous les étuds x tous les modimpls)
                ndarray (etuds x modimpls)
                (floats avec des NaN)
    etuds : listes des étudiants (dim. 0 de la matrice)
    ues : liste des UE du semestre
    modimpl_inscr_df: matrice d'inscription du semestre (etud x modimpl)
    modimpl_coefs: vecteur des coefficients de modules
    modimpl_mask: masque des modimpls à prendre en compte

    Résultat:
     - moyennes générales: pd.Series, index etudid
     - moyennes d'UE: DataFrame columns UE, rows etudid
     - coefficients d'UE: DataFrame, columns UE, rows etudid
        les coefficients effectifs de chaque UE pour chaque étudiant
        (sommes de coefs de modules pris en compte)
    """
    # Restreint aux modules sélectionnés:
    sem_matrix = sem_matrix[:, modimpl_mask]
    modimpl_inscr = modimpl_inscr_df.values[:, modimpl_mask]
    modimpl_coefs = modimpl_coefs[modimpl_mask]

    nb_etuds, nb_modules = sem_matrix.shape
    assert len(modimpl_coefs) == nb_modules
    nb_ues = len(ues)  # en comptant bonus

    # Enlève les NaN du numérateur:
    sem_matrix_no_nan = np.nan_to_num(sem_matrix, nan=0.0)
    # Ne prend pas en compte les notes des étudiants non inscrits au module:
    # Annule les notes:
    sem_matrix_inscrits = np.where(modimpl_inscr, sem_matrix_no_nan, 0.0)
    # Annule les coefs des modules où l'étudiant n'est pas inscrit:
    modimpl_coefs_etuds = np.where(
        modimpl_inscr, np.stack([modimpl_coefs.T] * nb_etuds), 0.0
    )
    # Annule les coefs des modules NaN (nb_etuds x nb_mods)
    modimpl_coefs_etuds_no_nan = np.where(
        np.isnan(sem_matrix), 0.0, modimpl_coefs_etuds
    )

    # ---------------------  Calcul des moyennes d'UE
    ue_modules = np.array(
        [[m.module.ue == ue for m in formsemestre.modimpls_sorted] for ue in ues]
    )[..., np.newaxis][:, modimpl_mask, :]
    modimpl_coefs_etuds_no_nan_stacked = np.stack(
        [modimpl_coefs_etuds_no_nan.T] * nb_ues
    )
    # nb_ue x nb_etuds x nb_mods : coefs prenant en compte NaN et inscriptions
    coefs = (modimpl_coefs_etuds_no_nan_stacked * ue_modules).swapaxes(1, 2)
    with np.errstate(invalid="ignore"):  # ignore les 0/0 (-> NaN)
        etud_moy_ue = (
            np.sum(coefs * sem_matrix_inscrits, axis=2) / np.sum(coefs, axis=2)
        ).T
    etud_moy_ue_df = pd.DataFrame(
        etud_moy_ue, index=modimpl_inscr_df.index, columns=[ue.id for ue in ues]
    )

    #  ---------------------  Calcul des moyennes générales
    if sco_preferences.get_preference("use_ue_coefs", formsemestre.id):
        # Cas avec coefficients d'UE forcés: (on met à zéro l'UE bonus)
        etud_coef_ue_df = pd.DataFrame(
            {ue.id: ue.coefficient if ue.type != UE_SPORT else 0.0 for ue in ues},
            index=modimpl_inscr_df.index,
            columns=[ue.id for ue in ues],
        )
        # remplace NaN par zéros dans les moyennes d'UE
        etud_moy_ue_df_no_nan = etud_moy_ue_df.fillna(0.0, inplace=False)
        # Si on voulait annuler les coef d'UE dont la moyenne d'UE est NaN
        #     etud_coef_ue_df_no_nan = etud_coef_ue_df.where(etud_moy_ue_df.notna(), 0.0)
        with np.errstate(invalid="ignore"):  # ignore les 0/0 (-> NaN)
            etud_moy_gen_s = (etud_coef_ue_df * etud_moy_ue_df_no_nan).sum(
                axis=1
            ) / etud_coef_ue_df.sum(axis=1)
    else:
        # Cas normal: pondère directement les modules
        etud_coef_ue_df = pd.DataFrame(
            coefs.sum(axis=2).T,
            index=modimpl_inscr_df.index,  # etudids
            columns=[ue.id for ue in ues],
        )
        with np.errstate(invalid="ignore"):  # ignore les 0/0 (-> NaN)
            etud_moy_gen = np.sum(
                modimpl_coefs_etuds_no_nan * sem_matrix_inscrits, axis=1
            ) / np.sum(modimpl_coefs_etuds_no_nan, axis=1)

            etud_moy_gen_s = pd.Series(etud_moy_gen, index=modimpl_inscr_df.index)

    return etud_moy_gen_s, etud_moy_ue_df, etud_coef_ue_df


def compute_malus(
    formsemestre: FormSemestre,
    sem_modimpl_moys: np.array,
    ues: list[UniteEns],
    modimpl_inscr_df: pd.DataFrame,
) -> pd.DataFrame:
    """Calcul le malus sur les UE
    Dans chaque UE, on peut avoir un  ou plusieurs modules de MALUS.
    Leurs notes sont positives ou négatives. leur somme sera _soustraite_ à la moyenne
    de chaque UE.
    Arguments:
        - sem_modimpl_moys :
            notes moyennes aux modules (tous les étuds x tous les modimpls)
            floats avec des NaN.
            En classique: sem_matrix, ndarray (etuds x modimpls)
            En APC: sem_cube, ndarray (etuds x modimpls x UEs non bonus)
        - ues: les ues du semestre (incluant le bonus sport)
        - modimpl_inscr_df: matrice d'inscription aux modules du semestre (etud x modimpl)

    Résultat: DataFrame de float, index etudid, columns: ue.id (sans NaN)
    """
    ues_idx = [ue.id for ue in ues]
    malus = pd.DataFrame(index=modimpl_inscr_df.index, columns=ues_idx, dtype=float)
    for ue in ues:
        if ue.type != UE_SPORT:
            modimpl_mask = np.array(
                [
                    (m.module.module_type == ModuleType.MALUS)
                    and (m.module.ue.id == ue.id)
                    for m in formsemestre.modimpls_sorted
                ]
            )
            malus_moys = sem_modimpl_moys[:, modimpl_mask].sum(axis=1)
            malus[ue.id] = malus_moys

    malus.fillna(0.0, inplace=True)
    return malus