##############################################################################
# ScoDoc
# Copyright (c) 1999 - 2022 Emmanuel Viennet.  All rights reserved.
# See LICENSE
##############################################################################

"""Résultats semestres classiques (non APC)
"""
import numpy as np
import pandas as pd
from app.comp import moy_mod, moy_ue, moy_sem, inscr_mod
from app.comp.res_common import NotesTableCompat
from app.models.formsemestre import FormSemestre


class ResultatsSemestreClassic(NotesTableCompat):
    """Résultats du semestre (formation classique): organisation des calculs."""

    _cached_attrs = NotesTableCompat._cached_attrs + (
        "modimpl_coefs",
        "modimpl_idx",
        "sem_matrix",
    )

    def __init__(self, formsemestre):
        super().__init__(formsemestre)

        if not self.load_cached():
            self.compute()
            self.store()
        # recalculé (aussi rapide que de les cacher)
        self.moy_min = self.etud_moy_gen.min()
        self.moy_max = self.etud_moy_gen.max()
        self.moy_moy = self.etud_moy_gen.mean()

    def compute(self):
        "Charge les notes et inscriptions et calcule les moyennes d'UE et gen."
        self.sem_matrix, self.modimpls_results = notes_sem_load_matrix(
            self.formsemestre
        )
        self.modimpl_inscr_df = inscr_mod.df_load_modimpl_inscr(self.formsemestre)
        self.modimpl_coefs = np.array(
            [m.module.coefficient for m in self.formsemestre.modimpls]
        )
        self.modimpl_idx = {m.id: i for i, m in enumerate(self.formsemestre.modimpls)}
        "l'idx de la colonne du mod modimpl.id est modimpl_idx[modimpl.id]"

        self.etud_moy_gen, self.etud_moy_ue = moy_ue.compute_ue_moys_classic(
            self.formsemestre,
            self.sem_matrix,
            self.ues,
            self.modimpl_inscr_df,
            self.modimpl_coefs,
        )
        self.etud_moy_gen_ranks = moy_sem.comp_ranks_series(self.etud_moy_gen)

    def get_etud_mod_moy(self, moduleimpl_id: int, etudid: int) -> float:
        """La moyenne de l'étudiant dans le moduleimpl
        Result: valeur float (peut être NaN) ou chaîne "NI" (non inscrit ou DEM)
        """
        return self.modimpls_results[moduleimpl_id].etuds_moy_module.get(etudid, "NI")

    def get_mod_stats(self, moduleimpl_id: int) -> dict:
        """Stats sur les notes obtenues dans un modimpl"""
        notes_series: pd.Series = self.modimpls_results[moduleimpl_id].etuds_moy_module
        nb_notes = len(notes_series)
        if not nb_notes:
            super().get_mod_stats(moduleimpl_id)
        return {
            # Series: Statistical methods from ndarray have been overridden to automatically
            # exclude missing data (currently represented as NaN)
            "moy": notes_series.mean(),  # donc sans prendre en compte les NaN
            "max": notes_series.max(),
            "min": notes_series.min(),
            "nb_notes": nb_notes,
            "nb_missing": sum(notes_series.isna()),
            "nb_valid_evals": sum(
                self.modimpls_results[moduleimpl_id].evaluations_completes
            ),
        }


def notes_sem_load_matrix(formsemestre: FormSemestre) -> tuple:
    """Calcule la matrice des notes du semestre
    (charge toutes les notes, calcule les moyenne des modules
    et assemble la matrice)
    Resultat:
        sem_matrix : 2d-array (etuds x modimpls)
        modimpls_results dict { modimpl.id : ModuleImplResultsClassic }
    """
    modimpls_results = {}
    modimpls_notes = []
    for modimpl in formsemestre.modimpls:
        mod_results = moy_mod.ModuleImplResultsClassic(modimpl)
        etuds_moy_module = mod_results.compute_module_moy()
        modimpls_results[modimpl.id] = mod_results
        modimpls_notes.append(etuds_moy_module)
    return (
        notes_sem_assemble_matrix(modimpls_notes),
        modimpls_results,
    )


def notes_sem_assemble_matrix(modimpls_notes: list[pd.Series]) -> np.ndarray:
    """Réuni les notes moyennes des modules du semestre en une matrice

    modimpls_notes : liste des moyennes de module
                     (Series rendus par compute_module_moy, index: etud)
    Resultat: ndarray (etud x module)
    """
    modimpls_notes_arr = [s.values for s in modimpls_notes]
    modimpls_notes = np.stack(modimpls_notes_arr)
    # passe de (mod x etud) à (etud x mod)
    return modimpls_notes.T