"""
Test calcul moyennes UE
"""
import numpy as np
from tests.unit import setup

from app import db
from app.comp import moy_ue
from app.comp import inscr_mod
from app.models import FormSemestre, Evaluation, ModuleImplInscription
from app.scodoc import sco_saisie_notes
from app.scodoc.sco_codes_parcours import UE_SPORT
from app.scodoc.sco_utils import NOTES_NEUTRALISE
from app.scodoc import sco_exceptions


def test_ue_moy(test_client):
    """Test calcul moyenne UE avec saisie des notes via ScoDoc"""
    ue_coefs = (1.0, 2.0, 3.0)  # coefs des modules vers les 3 UE
    nb_ues = len(ue_coefs)
    nb_mods = 2  # 2 modules
    (
        G,
        formation_id,
        formsemestre_id,
        evaluation_ids,
        ue1,
        ue2,
        ue3,
    ) = setup.build_modules_with_evaluations(ue_coefs=ue_coefs, nb_mods=nb_mods)
    assert len(evaluation_ids) == nb_mods
    formsemestre = FormSemestre.query.get(formsemestre_id)
    evaluation1 = Evaluation.query.get(evaluation_ids[0])
    evaluation2 = Evaluation.query.get(evaluation_ids[1])
    etud = G.create_etud(nom="test")
    G.inscrit_etudiant(formsemestre_id, etud)
    etudid = etud["etudid"]
    # e1 est l'éval du module 1, et e2 l'éval du module 2
    e1p1, e1p2, e1p3 = 1.0, 2.0, 3.0  # poids de l'éval 1 vers les UE 1, 2 et 3
    e2p1, e2p2, e2p3 = 3.0, 1.0, 5.0  # poids de l'éval 2 vers les UE
    evaluation1.set_ue_poids_dict({ue1.id: e1p1, ue2.id: e1p2, ue3.id: e1p3})
    evaluation2.set_ue_poids_dict({ue1.id: e2p1, ue2.id: e2p2, ue3.id: e2p3})
    # Les coefs des évaluations:
    coef_e1, coef_e2 = 7.0, 11.0
    evaluation1.coefficient = coef_e1
    evaluation2.coefficient = coef_e2
    # Les moduleimpls
    modimpls = [evaluation1.moduleimpl, evaluation2.moduleimpl]
    # Check inscriptions modules
    modimpl_inscr_df = inscr_mod.df_load_modimpl_inscr(formsemestre)
    assert (modimpl_inscr_df.values == np.array([[1, 1]])).all()
    # Coefs des modules vers les UE:
    modimpl_coefs_df, ues, modimpls = moy_ue.df_load_modimpl_coefs(formsemestre)
    assert modimpl_coefs_df.shape == (nb_ues, nb_mods)
    assert len(ues) == nb_ues
    assert len(modimpls) == nb_mods
    assert (modimpl_coefs_df.values == np.array([ue_coefs] * nb_mods).transpose()).all()

    # --- Change les notes et recalcule les moyennes
    # (rappel: on a deux évaluations: evaluation1, evaluation2, et un seul étudiant)
    def change_notes(n1, n2):
        # Saisie d'une note dans chaque éval
        _ = sco_saisie_notes.notes_add(G.default_user, evaluation1.id, [(etudid, n1)])
        _ = sco_saisie_notes.notes_add(G.default_user, evaluation2.id, [(etudid, n2)])
        # Recalcul des moyennes
        sem_cube, _, _ = moy_ue.notes_sem_load_cube(formsemestre)
        # Masque de tous les modules _sauf_ les bonus (sport)
        modimpl_mask = [
            modimpl.module.ue.type != UE_SPORT
            for modimpl in formsemestre.modimpls_sorted
        ]
        etuds = formsemestre.etuds.all()
        etud_moy_ue = moy_ue.compute_ue_moys_apc(
            sem_cube,
            etuds,
            modimpls,
            ues,
            modimpl_inscr_df,
            modimpl_coefs_df,
            modimpl_mask,
        )
        return etud_moy_ue

    # Cas simple: 1 eval / module, notes normales,
    # coefs non nuls.
    n1, n2 = 5.0, 13.0  # notes aux 2 evals (1 dans chaque module)
    etud_moy_ue = change_notes(n1, n2)
    assert etud_moy_ue.shape == (1, nb_ues)  # 1 étudiant
    assert etud_moy_ue[ue1.id][etudid] == (n1 + n2) / 2
    assert etud_moy_ue[ue2.id][etudid] == (n1 + n2) / 2
    assert etud_moy_ue[ue3.id][etudid] == (n1 + n2) / 2
    #
    # ABS à un module (note comptée comme 0)
    n1, n2 = None, 13.0  # notes aux 2 evals (1 dans chaque module)
    etud_moy_ue = change_notes(n1, n2)
    assert etud_moy_ue[ue1.id][etudid] == n2 / 2  # car n1 est zéro
    assert etud_moy_ue[ue2.id][etudid] == n2 / 2
    assert etud_moy_ue[ue3.id][etudid] == n2 / 2
    # EXC à un module
    n1, n2 = 5.0, NOTES_NEUTRALISE
    etud_moy_ue = change_notes(n1, n2)
    assert (etud_moy_ue.values == n1).all()
    # Désinscrit l'étudiant du module 2:
    inscr = ModuleImplInscription.query.filter_by(
        moduleimpl_id=evaluation2.moduleimpl.id, etudid=etudid
    ).first()
    db.session.delete(inscr)
    db.session.commit()
    modimpl_inscr_df = inscr_mod.df_load_modimpl_inscr(formsemestre)
    assert (modimpl_inscr_df.values == np.array([[1, 0]])).all()
    n1, n2 = 5.0, NOTES_NEUTRALISE
    # On ne doit pas pouvoir saisir de note sans être inscrit:
    exception_raised = False
    try:
        etud_moy_ue = change_notes(n1, n2)
    except sco_exceptions.NoteProcessError:
        exception_raised = True
    assert exception_raised
    # Recalcule les notes:
    sem_cube, _, _ = moy_ue.notes_sem_load_cube(formsemestre)
    etuds = formsemestre.etuds.all()
    modimpl_mask = [
        modimpl.module.ue.type != UE_SPORT for modimpl in formsemestre.modimpls_sorted
    ]
    etud_moy_ue = moy_ue.compute_ue_moys_apc(
        sem_cube, etuds, modimpls, ues, modimpl_inscr_df, modimpl_coefs_df, modimpl_mask
    )
    assert etud_moy_ue[ue1.id][etudid] == n1
    assert etud_moy_ue[ue2.id][etudid] == n1
    assert etud_moy_ue[ue3.id][etudid] == n1