# -*- mode: python -*-
# -*- coding: utf-8 -*-

##############################################################################
#
# Gestion scolarite IUT
#
# Copyright (c) 1999 - 2022 Emmanuel Viennet.  All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#   Emmanuel Viennet      emmanuel.viennet@viennet.net
#
##############################################################################

"""Fonctions de calcul des moyennes de modules (modules, ressources ou SAÉ)

Pour les formations classiques et le BUT

Rappel: pour éviter les confusions, on appelera *poids* les coefficients d'une
évaluation dans un module, et *coefficients* ceux utilisés pour le calcul de la
moyenne générale d'une UE.
"""
from dataclasses import dataclass
import numpy as np
import pandas as pd

from app import db
from app.models import ModuleImpl, Evaluation, EvaluationUEPoids
from app.scodoc import sco_utils as scu
from app.scodoc.sco_codes_parcours import UE_SPORT
from app.scodoc.sco_exceptions import ScoValueError


@dataclass
class EvaluationEtat:
    """Classe pour stocker quelques infos sur les résultats d'une évaluation"""

    evaluation_id: int
    nb_attente: int
    is_complete: bool


class ModuleImplResults:
    """Classe commune à toutes les formations (standard et APC).
    Les notes des étudiants d'un moduleimpl.
    Les poids des évals sont à part car on en a besoin sans les notes pour les
    tableaux de bord.
    Les attributs sont tous des objets simples cachables dans Redis;
    les caches sont gérés par  ResultatsSemestre.
    """

    def __init__(self, moduleimpl: ModuleImpl):
        self.moduleimpl_id = moduleimpl.id
        self.module_id = moduleimpl.module.id
        self.etudids = None
        "liste des étudiants inscrits au SEMESTRE"

        self.nb_inscrits_module = None
        "nombre d'inscrits (non DEM) à ce module"
        self.evaluations_completes = []
        "séquence de booléens, indiquant les évals à prendre en compte."
        self.evaluations_completes_dict = {}
        "{ evaluation.id : bool } indique si à prendre en compte ou non."
        self.evaluations_etat = {}
        "{ evaluation_id: EvaluationEtat }"
        #
        self.evals_notes = None
        """DataFrame, colonnes: EVALS, Lignes: etudid (inscrits au SEMESTRE)
            valeur: notes brutes, float ou NOTES_ATTENTE, NOTES_NEUTRALISE,
            NOTES_ABSENCE.
            Les NaN désignent les notes manquantes (non saisies).
        """
        self.etuds_moy_module = None
        """DataFrame, colonnes UE, lignes etud
            = la note de l'étudiant dans chaque UE pour ce module.
            ou NaN si les évaluations (dans lesquelles l'étudiant a des notes)
            ne donnent pas de coef vers cette UE.
        """
        self.load_notes()

    def load_notes(self):  # ré-écriture de df_load_modimpl_notes
        """Charge toutes les notes de toutes les évaluations du module.
        Dataframe evals_notes
            colonnes: le nom de la colonne est l'evaluation_id (int)
            index (lignes): etudid (int)

        L'ensemble des étudiants est celui des inscrits au SEMESTRE.

        Les notes sont "brutes" (séries de floats) et peuvent prendre les valeurs:
            note : float (valeur enregistrée brute, NON normalisée sur 20)
            pas de note: NaN (rien en bd, ou étudiant non inscrit au module)
            absent: NOTES_ABSENCE (NULL en bd)
            excusé: NOTES_NEUTRALISE (voir sco_utils)
            attente: NOTES_ATTENTE

        Évaluation "complete" (prise en compte dans les calculs) si:
        - soit tous les étudiants inscrits au module ont des notes
        - soit elle a été déclarée "à prise en compte immédiate" (publish_incomplete)

        Évaluation "attente" (prise en compte dans les calculs, mais il y
        manque des notes) ssi il y a des étudiants inscrits au semestre et au module
        qui ont des notes ATT.
        """
        moduleimpl = ModuleImpl.query.get(self.moduleimpl_id)
        self.etudids = self._etudids()

        # --- Calcul nombre d'inscrits pour déterminer les évaluations "completes":
        # on prend les inscrits au module ET au semestre (donc sans démissionnaires)
        inscrits_module = {ins.etud.id for ins in moduleimpl.inscriptions}.intersection(
            self.etudids
        )
        self.nb_inscrits_module = len(inscrits_module)

        # dataFrame vide, index = tous les inscrits au SEMESTRE
        evals_notes = pd.DataFrame(index=self.etudids, dtype=float)
        self.evaluations_completes = []
        self.evaluations_completes_dict = {}
        for evaluation in moduleimpl.evaluations:
            eval_df = self._load_evaluation_notes(evaluation)
            # is_complete ssi tous les inscrits (non dem) au semestre ont une note
            # ou évaluaton déclarée "à prise en compte immédiate"
            is_complete = (
                len(set(eval_df.index).intersection(self.etudids))
                == self.nb_inscrits_module
            ) or evaluation.publish_incomplete  # immédiate
            self.evaluations_completes.append(is_complete)
            self.evaluations_completes_dict[evaluation.id] = is_complete

            # NULL en base => ABS (= -999)
            eval_df.fillna(scu.NOTES_ABSENCE, inplace=True)
            # Ce merge ne garde que les étudiants inscrits au module
            # et met à NULL les notes non présentes
            #  (notes non saisies ou etuds non inscrits au module):
            evals_notes = evals_notes.merge(
                eval_df, how="left", left_index=True, right_index=True
            )
            # Notes en attente: (on prend dans evals_notes pour ne pas avoir les dem.)
            nb_att = sum(evals_notes[str(evaluation.id)] == scu.NOTES_ATTENTE)
            self.evaluations_etat[evaluation.id] = EvaluationEtat(
                evaluation_id=evaluation.id, nb_attente=nb_att, is_complete=is_complete
            )

        # Force columns names to integers (evaluation ids)
        evals_notes.columns = pd.Int64Index(
            [int(x) for x in evals_notes.columns], dtype="int"
        )
        self.evals_notes = evals_notes

    def _load_evaluation_notes(self, evaluation: Evaluation) -> pd.DataFrame:
        """Charge les notes de l'évaluation
        Resultat: dataframe, index: etudid ayant une note, valeur: note brute.
        """
        eval_df = pd.read_sql_query(
            """SELECT n.etudid, n.value AS "%(evaluation_id)s"
            FROM notes_notes n, notes_moduleimpl_inscription i
            WHERE evaluation_id=%(evaluation_id)s
            AND n.etudid = i.etudid
            AND i.moduleimpl_id = %(moduleimpl_id)s
            """,
            db.engine,
            params={
                "evaluation_id": evaluation.id,
                "moduleimpl_id": evaluation.moduleimpl.id,
            },
            index_col="etudid",
        )
        eval_df[str(evaluation.id)] = pd.to_numeric(eval_df[str(evaluation.id)])
        return eval_df

    def _etudids(self):
        """L'index du dataframe est la liste de tous les étudiants inscrits au semestre"""
        return [
            inscr.etudid
            for inscr in ModuleImpl.query.get(
                self.moduleimpl_id
            ).formsemestre.inscriptions
        ]

    def get_evaluations_coefs(self, moduleimpl: ModuleImpl) -> np.array:
        """Coefficients des évaluations, met à zéro ceux des évals incomplètes.
        Résultat: 2d-array of floats, shape (nb_evals, 1)
        """
        return (
            np.array(
                [e.coefficient for e in moduleimpl.evaluations],
                dtype=float,
            )
            * self.evaluations_completes
        ).reshape(-1, 1)

    def get_eval_notes_sur_20(self, moduleimpl: ModuleImpl) -> np.array:
        """Les notes des évaluations,
        remplace les  ATT, EXC, ABS, NaN par zéro et mets les notes sur 20.
        Résultat: 2d array of floats, shape nb_etuds x nb_evaluations
        """
        return np.where(
            self.evals_notes.values > scu.NOTES_ABSENCE, self.evals_notes.values, 0.0
        ) / [e.note_max / 20.0 for e in moduleimpl.evaluations]


class ModuleImplResultsAPC(ModuleImplResults):
    "Calcul des moyennes de modules à la mode BUT"

    def compute_module_moy(
        self,
        evals_poids_df: pd.DataFrame,
    ) -> pd.DataFrame:
        """Calcule les moyennes des étudiants dans ce module

        Argument: evals_poids: DataFrame, colonnes: UEs, Lignes: EVALs

        Résultat: DataFrame, colonnes UE, lignes etud
            = la note de l'étudiant dans chaque UE pour ce module.
            ou NaN si les évaluations (dans lesquelles l'étudiant a des notes)
            ne donnent pas de coef vers cette UE.
        """
        moduleimpl = ModuleImpl.query.get(self.moduleimpl_id)
        nb_etuds, nb_evals = self.evals_notes.shape
        nb_ues = evals_poids_df.shape[1]
        assert evals_poids_df.shape[0] == nb_evals  # compat notes/poids
        if nb_etuds == 0:
            return pd.DataFrame(index=[], columns=evals_poids_df.columns)
        evals_coefs = self.get_evaluations_coefs(moduleimpl)
        evals_poids = evals_poids_df.values * evals_coefs
        # -> evals_poids shape : (nb_evals, nb_ues)
        assert evals_poids.shape == (nb_evals, nb_ues)
        evals_notes_20 = self.get_eval_notes_sur_20(moduleimpl)

        # Les poids des évals pour chaque étudiant: là où il a des notes
        # non neutralisées
        # (ABS n'est  pas neutralisée, mais ATTENTE et NEUTRALISE oui)
        # Note: les NaN sont remplacés par des 0 dans evals_notes
        #  et dans dans evals_poids_etuds
        #  (rappel: la comparaison est toujours false face à un NaN)
        # shape: (nb_etuds, nb_evals, nb_ues)
        poids_stacked = np.stack([evals_poids] * nb_etuds)
        evals_poids_etuds = np.where(
            np.stack([self.evals_notes.values] * nb_ues, axis=2) > scu.NOTES_NEUTRALISE,
            poids_stacked,
            0,
        )
        # Calcule la moyenne pondérée sur les notes disponibles:
        evals_notes_stacked = np.stack([evals_notes_20] * nb_ues, axis=2)
        with np.errstate(invalid="ignore"):  # ignore les 0/0 (-> NaN)
            etuds_moy_module = np.sum(
                evals_poids_etuds * evals_notes_stacked, axis=1
            ) / np.sum(evals_poids_etuds, axis=1)
        self.etuds_moy_module = pd.DataFrame(
            etuds_moy_module,
            index=self.evals_notes.index,
            columns=evals_poids_df.columns,
        )
        return self.etuds_moy_module


def load_evaluations_poids(moduleimpl_id: int) -> tuple[pd.DataFrame, list]:
    """Charge poids des évaluations d'un module et retourne un dataframe
    rows = evaluations, columns = UE, value = poids (float).
    Les valeurs manquantes (évaluations sans coef vers des UE) sont
    remplies: 1 si le coef de ce module dans l'UE est non nul, zéro sinon
    (sauf pour module bonus, defaut à 1)
    Résultat: (evals_poids, liste de UEs du semestre sauf le sport)
    """
    modimpl: ModuleImpl = ModuleImpl.query.get(moduleimpl_id)
    evaluations = Evaluation.query.filter_by(moduleimpl_id=moduleimpl_id).all()
    ues = modimpl.formsemestre.query_ues(with_sport=False).all()
    ue_ids = [ue.id for ue in ues]
    evaluation_ids = [evaluation.id for evaluation in evaluations]
    evals_poids = pd.DataFrame(columns=ue_ids, index=evaluation_ids, dtype=float)
    for ue_poids in EvaluationUEPoids.query.join(
        EvaluationUEPoids.evaluation
    ).filter_by(moduleimpl_id=moduleimpl_id):
        try:
            evals_poids[ue_poids.ue_id][ue_poids.evaluation_id] = ue_poids.poids
        except KeyError as exc:
            pass  # poids vers des UE qui n'existent plus ou sont dans un autre semestre...

    # Initialise poids non enregistrés:
    default_poids = 1.0 if modimpl.module.ue.type == UE_SPORT else 0.0

    if np.isnan(evals_poids.values.flat).any():
        ue_coefs = modimpl.module.get_ue_coef_dict()
        for ue in ues:
            evals_poids[ue.id][evals_poids[ue.id].isna()] = (
                1 if ue_coefs.get(ue.id, default_poids) > 0 else 0
            )

    return evals_poids, ues


def moduleimpl_is_conforme(
    moduleimpl, evals_poids: pd.DataFrame, modules_coefficients: pd.DataFrame
) -> bool:
    """Vérifie que les évaluations de ce moduleimpl sont bien conformes
    au PN.
    Un module est dit *conforme* si et seulement si la somme des poids de ses
    évaluations vers une UE de coefficient non nul est non nulle.

    Argument: evals_poids: DataFrame, colonnes: UEs, Lignes: EVALs
    NB: les UEs dans evals_poids sont sans le bonus sport
    """
    nb_evals, nb_ues = evals_poids.shape
    if nb_evals == 0:
        return True  # modules vides conformes
    if nb_ues == 0:
        return False  # situation absurde (pas d'UE)
    if len(modules_coefficients) != nb_ues:
        raise ValueError("moduleimpl_is_conforme: nb ue incoherent")
    module_evals_poids = evals_poids.transpose().sum(axis=1).to_numpy() != 0
    check = all(
        (modules_coefficients[moduleimpl.module_id].to_numpy() != 0)
        == module_evals_poids
    )
    return check


class ModuleImplResultsClassic(ModuleImplResults):
    "Calcul des moyennes de modules des formations classiques"

    def compute_module_moy(self) -> pd.Series:
        """Calcule les moyennes des étudiants dans ce module

        Résultat: Series, lignes etud
            = la note (moyenne) de l'étudiant pour ce module.
            ou NaN si les évaluations (dans lesquelles l'étudiant a des notes)
            ne donnent pas de coef.
        """
        modimpl = ModuleImpl.query.get(self.moduleimpl_id)
        nb_etuds, nb_evals = self.evals_notes.shape
        if nb_etuds == 0:
            return pd.Series()
        evals_coefs = self.get_evaluations_coefs(modimpl).reshape(-1)
        assert evals_coefs.shape == (nb_evals,)
        evals_notes_20 = self.get_eval_notes_sur_20(modimpl)
        # Les coefs des évals pour chaque étudiant: là où il a des notes
        #  non neutralisées
        # (ABS n'est  pas neutralisée, mais ATTENTE et NEUTRALISE oui)
        # Note: les NaN sont remplacés par des 0 dans evals_notes
        #  et dans dans evals_poids_etuds
        #  (rappel: la comparaison est toujours False face à un NaN)
        # shape: (nb_etuds, nb_evals)
        coefs_stacked = np.stack([evals_coefs] * nb_etuds)
        evals_coefs_etuds = np.where(
            self.evals_notes.values > scu.NOTES_NEUTRALISE, coefs_stacked, 0
        )
        # Calcule la moyenne pondérée sur les notes disponibles:
        with np.errstate(invalid="ignore"):  # ignore les 0/0 (-> NaN)
            etuds_moy_module = np.sum(
                evals_coefs_etuds * evals_notes_20, axis=1
            ) / np.sum(evals_coefs_etuds, axis=1)

        self.etuds_moy_module = pd.Series(
            etuds_moy_module,
            index=self.evals_notes.index,
        )
        return self.etuds_moy_module