##############################################################################
# ScoDoc
# Copyright (c) 1999 - 2024 Emmanuel Viennet.  All rights reserved.
# See LICENSE
##############################################################################

"""Résultats semestres classiques (non APC)
"""
import time
import numpy as np
import pandas as pd
from sqlalchemy.sql import text

from flask import g, url_for

from app import db
from app import log
from app.comp import moy_mat, moy_mod, moy_sem, moy_ue, inscr_mod
from app.comp.res_compat import NotesTableCompat
from app.comp.bonus_spo import BonusSport
from app.models import ScoDocSiteConfig
from app.models.etudiants import Identite
from app.models.formsemestre import FormSemestre
from app.models.ues import UniteEns
from app.scodoc.codes_cursus import UE_SPORT
from app.scodoc.sco_exceptions import ScoValueError
from app.scodoc import sco_preferences
from app.scodoc.sco_utils import ModuleType


class ResultatsSemestreClassic(NotesTableCompat):
    """Résultats du semestre (formation classique): organisation des calculs."""

    _cached_attrs = NotesTableCompat._cached_attrs + (
        "modimpl_coefs",
        "modimpl_idx",
        "sem_matrix",
        "mod_rangs",
    )

    def __init__(self, formsemestre):
        super().__init__(formsemestre)
        self.sem_matrix: np.ndarray = None
        "sem_matrix : 2d-array (etuds x modimpls)"

        if not self.load_cached():
            t0 = time.time()
            self.compute()
            t1 = time.time()
            self.store()
            t2 = time.time()
            log(
                f"""+++ ResultatsSemestreClassic: cached formsemestre_id={
                    formsemestre.id} ({(t1-t0):g}s +{(t2-t1):g}s) +++"""
            )
        # recalculé (aussi rapide que de les cacher)
        self.moy_min = self.etud_moy_gen.min()
        self.moy_max = self.etud_moy_gen.max()
        self.moy_moy = self.etud_moy_gen.mean()

    def compute(self):
        "Charge les notes et inscriptions et calcule les moyennes d'UE et gen."
        self.sem_matrix, self.modimpls_results = notes_sem_load_matrix(
            self.formsemestre
        )
        self.modimpl_inscr_df = inscr_mod.df_load_modimpl_inscr(self.formsemestre)
        self.modimpl_coefs = np.array(
            [m.module.coefficient or 0.0 for m in self.formsemestre.modimpls_sorted]
        )
        self.modimpl_idx = {
            m.id: i for i, m in enumerate(self.formsemestre.modimpls_sorted)
        }
        "l'idx de la colonne du mod modimpl.id est modimpl_idx[modimpl.id]"

        modimpl_standards_mask = np.array(
            [
                (m.module.module_type == ModuleType.STANDARD)
                and (m.module.ue.type != UE_SPORT)
                for m in self.formsemestre.modimpls_sorted
            ]
        )
        (
            self.etud_moy_gen,
            self.etud_moy_ue,
            self.etud_coef_ue_df,
        ) = moy_ue.compute_ue_moys_classic(
            self.formsemestre,
            self.sem_matrix,
            self.ues,
            self.modimpl_inscr_df,
            self.modimpl_coefs,
            modimpl_standards_mask,
            block=self.formsemestre.block_moyennes,
        )
        # --- Modules de MALUS sur les UEs et la moyenne générale
        self.malus = moy_ue.compute_malus(
            self.formsemestre, self.sem_matrix, self.ues, self.modimpl_inscr_df
        )
        self.etud_moy_ue -= self.malus
        # ajuste la moyenne générale (à l'aide des coefs d'UE)
        self.etud_moy_gen -= (self.etud_coef_ue_df * self.malus).sum(
            axis=1
        ) / self.etud_coef_ue_df.sum(axis=1)

        # --- Bonus Sport & Culture
        bonus_class = ScoDocSiteConfig.get_bonus_sport_class()
        if bonus_class is not None:
            bonus: BonusSport = bonus_class(
                self.formsemestre,
                self.sem_matrix,
                self.ues,
                self.modimpl_inscr_df,
                self.modimpl_coefs,
                self.etud_moy_gen,
                self.etud_moy_ue,
            )
            self.bonus_ues = bonus.get_bonus_ues()
            if self.bonus_ues is not None:
                self.etud_moy_ue += self.bonus_ues  # somme les dataframes
                self.etud_moy_ue.clip(lower=0.0, upper=20.0, inplace=True)
            bonus_mg = bonus.get_bonus_moy_gen()
            if bonus_mg is None and self.bonus_ues is not None:
                # pas de bonus explicite sur la moyenne générale
                # on l'ajuste pour refléter les modifs d'UE, à l'aide des coefs d'UE.
                bonus_mg = (self.etud_coef_ue_df * self.bonus_ues).sum(
                    axis=1
                ) / self.etud_coef_ue_df.sum(axis=1)
                self.etud_moy_gen += bonus_mg
            elif bonus_mg is not None:
                # Applique le bonus moyenne générale renvoyé
                self.etud_moy_gen += bonus_mg

            # compat nt, utilisé pour l'afficher sur les bulletins:
            self.bonus = bonus_mg

        # --- UE capitalisées
        self.apply_capitalisation()

        # Clippe toutes les moyennes dans [0,20]
        self.etud_moy_ue.clip(lower=0.0, upper=20.0, inplace=True)
        self.etud_moy_gen.clip(lower=0.0, upper=20.0, inplace=True)

        # --- Classements:
        self.compute_rangs()

        # --- En option, moyennes par matières
        if sco_preferences.get_preference("bul_show_matieres", self.formsemestre.id):
            self.compute_moyennes_matieres()

    def compute_rangs(self):
        """Calcul des rangs (classements) dans le semestre (moy. gen.), les UE
        et les modules.
        """
        # rangs moy gen et UEs sont calculées par la méthode commune à toutes les formations:
        super().compute_rangs()
        # les rangs des modules n'existent que dans les formations classiques:
        self.mod_rangs = {}
        for modimpl_result in self.modimpls_results.values():
            # ne prend que les rangs sous forme de chaines:
            rangs = moy_sem.comp_ranks_series(modimpl_result.etuds_moy_module)[0]
            self.mod_rangs[modimpl_result.moduleimpl_id] = (
                rangs,
                modimpl_result.nb_inscrits_module,
            )

    def get_etud_mod_moy(self, moduleimpl_id: int, etudid: int) -> float:
        """La moyenne de l'étudiant dans le moduleimpl
        Result: valeur float (peut être NaN) ou chaîne "NI" (non inscrit ou DEM)
        """
        try:
            if self.modimpl_inscr_df[moduleimpl_id][etudid]:
                return self.modimpls_results[moduleimpl_id].etuds_moy_module[etudid]
        except KeyError:
            pass
        return "NI"

    def get_mod_stats(self, moduleimpl_id: int) -> dict:
        """Stats sur les notes obtenues dans un modimpl"""
        notes_series: pd.Series = self.modimpls_results[moduleimpl_id].etuds_moy_module
        nb_notes = len(notes_series)
        if not nb_notes:
            super().get_mod_stats(moduleimpl_id)
        return {
            # Series: Statistical methods from ndarray have been overridden to automatically
            # exclude missing data (currently represented as NaN)
            "moy": notes_series.mean(),  # donc sans prendre en compte les NaN
            "max": notes_series.max(),
            "min": notes_series.min(),
            "nb_notes": nb_notes,
            "nb_missing": sum(notes_series.isna()),
            "nb_valid_evals": sum(
                self.modimpls_results[moduleimpl_id].evaluations_completes
            ),
        }

    def modimpl_notes(
        self,
        modimpl_id: int,
        ue_id: int = None,
    ) -> np.ndarray:
        """Les notes moyennes des étudiants du sem. à ce modimpl dans cette ue.
        Utile pour stats bottom tableau recap.
        ue_id n'est pas utilisé ici (formations classiques)
        Résultat: 1d array of float
        """
        i = self.modimpl_idx[modimpl_id]
        return self.sem_matrix[:, i]

    def compute_moyennes_matieres(self):
        """Calcul les moyennes par matière. Doit être appelée au besoin, en fin de compute."""
        self.moyennes_matieres = moy_mat.compute_mat_moys_classic(
            self.formsemestre,
            self.sem_matrix,
            self.ues,
            self.modimpl_inscr_df,
            self.modimpl_coefs,
        )

    def compute_etud_ue_coef(self, etudid: int, ue: UniteEns) -> float:
        """Détermine le coefficient de l'UE pour cet étudiant.
        N'est utilisé que pour l'injection des UE capitalisées dans la
        moyenne générale.
        Coef = somme des coefs des modules de l'UE auxquels il est inscrit
        """
        coef = comp_etud_sum_coef_modules_ue(self.formsemestre.id, etudid, ue["ue_id"])
        if coef is not None:  # inscrit à au moins un module de cette UE
            return coef
        # arfff: aucun moyen de déterminer le coefficient de façon sûre
        log(
            f"""* oups: calcul coef UE impossible\nformsemestre_id='{self.formsemestre.id
            }'\netudid='{etudid}'\nue={ue}"""
        )
        etud = Identite.get_etud(etudid)
        raise ScoValueError(
            f"""<div class="scovalueerror"><p>Coefficient de l'UE capitalisée {ue.acronyme}
            impossible à déterminer pour l'étudiant <a href="{
            url_for("scolar.fiche_etud", scodoc_dept=g.scodoc_dept, etudid=etudid)
            }" class="discretelink">{etud.nom_disp()}</a></p>
            <p>Il faut <a href="{
            url_for("notes.formsemestre_edit_uecoefs", scodoc_dept=g.scodoc_dept,
                formsemestre_id=self.formsemestre.id, err_ue_id=ue["ue_id"],
            )
            }">saisir le coefficient de cette UE avant de continuer</a></p>
            </div>
            """,
            safe=True,
        )


def notes_sem_load_matrix(formsemestre: FormSemestre) -> tuple[np.ndarray, dict]:
    """Calcule la matrice des notes du semestre
    (charge toutes les notes, calcule les moyennes des modules
    et assemble la matrice)
    Resultat:
        sem_matrix : 2d-array (etuds x modimpls)
        modimpls_results dict { modimpl.id : ModuleImplResultsClassic }
    """
    modimpls_results = {}
    modimpls_notes = []
    etudids, etudids_actifs = formsemestre.etudids_actifs()
    for modimpl in formsemestre.modimpls_sorted:
        mod_results = moy_mod.ModuleImplResultsClassic(modimpl, etudids, etudids_actifs)
        etuds_moy_module = mod_results.compute_module_moy()
        modimpls_results[modimpl.id] = mod_results
        modimpls_notes.append(etuds_moy_module)
    return (
        notes_sem_assemble_matrix(modimpls_notes),
        modimpls_results,
    )


def notes_sem_assemble_matrix(modimpls_notes: list[pd.Series]) -> np.ndarray:
    """Réuni les notes moyennes des modules du semestre en une matrice

    modimpls_notes : liste des moyennes de module
                     (Series rendus par compute_module_moy, index: etud)
    Resultat: ndarray (etud x module)
    """
    if not modimpls_notes:
        return np.zeros((0, 0), dtype=float)
    modimpls_notes_arr = [s.values for s in modimpls_notes]
    modimpls_notes = np.stack(modimpls_notes_arr)
    # passe de (mod x etud) à (etud x mod)
    return modimpls_notes.T


def comp_etud_sum_coef_modules_ue(formsemestre_id, etudid, ue_id):
    """Somme des coefficients des modules de l'UE dans lesquels cet étudiant est inscrit
    ou None s'il n'y a aucun module.
    """
    # comme l'ancien notes_table.comp_etud_sum_coef_modules_ue
    # mais en raw sqlalchemy et la somme en SQL
    sql = text(
        """
    SELECT sum(mod.coefficient)
    FROM notes_modules mod, notes_moduleimpl mi, notes_moduleimpl_inscription ins
    WHERE mod.id = mi.module_id
    and ins.etudid = :etudid
    and ins.moduleimpl_id = mi.id
    and mi.formsemestre_id = :formsemestre_id
    and mod.ue_id = :ue_id
    """
    )
    cursor = db.session.execute(
        sql, {"etudid": etudid, "formsemestre_id": formsemestre_id, "ue_id": ue_id}
    )
    r = cursor.fetchone()
    if r is None:
        return None
    return r[0]