"""
Ecrit par Matthias Hartmann.
"""
from datetime import date, datetime, time, timedelta
from pytz import UTC

from app import log, db
import app.scodoc.sco_utils as scu
from app.models.assiduites import Assiduite, Justificatif, compute_assiduites_justified
from app.models.etudiants import Identite
from app.models.formsemestre import FormSemestre, FormSemestreInscription
from app.scodoc import sco_formsemestre_inscriptions
from app.scodoc import sco_preferences
from app.scodoc import sco_cache
from app.scodoc import sco_etud
from flask_sqlalchemy.query import Query


class CountCalculator:
    """Classe qui gére le comptage des assiduités"""

    def __init__(
        self,
        morning: time = time(8, 0),
        noon: time = time(12, 0),
        after_noon: time = time(14, 00),
        evening: time = time(18, 0),
        skip_saturday: bool = True,
    ) -> None:
        self.morning: time = morning
        self.noon: time = noon
        self.after_noon: time = after_noon
        self.evening: time = evening
        self.skip_saturday: bool = skip_saturday

        delta_total: timedelta = datetime.combine(date.min, evening) - datetime.combine(
            date.min, morning
        )
        delta_lunch: timedelta = datetime.combine(
            date.min, after_noon
        ) - datetime.combine(date.min, noon)

        self.hour_per_day: float = (delta_total - delta_lunch).total_seconds() / 3600

        self.days: list[date] = []
        self.half_days: list[tuple[date, bool]] = []  # tuple -> (date, morning:bool)
        self.hours: float = 0.0

        self.count: int = 0

    def reset(self):
        """Remet à zero le compteur"""
        self.days = []
        self.half_days = []
        self.hours = 0.0
        self.count = 0

    def add_half_day(self, day: date, is_morning: bool = True):
        """Ajoute une demi journée dans le comptage"""
        key: tuple[date, bool] = (day, is_morning)
        if key not in self.half_days:
            self.half_days.append(key)

    def add_day(self, day: date):
        """Ajoute un jour dans le comptage"""
        if day not in self.days:
            self.days.append(day)

    def check_in_morning(self, period: tuple[datetime, datetime]) -> bool:
        """Vérifiée si la période donnée fait partie du matin
        (Test sur la date de début)
        """

        interval_morning: tuple[datetime, datetime] = (
            scu.localize_datetime(datetime.combine(period[0].date(), self.morning)),
            scu.localize_datetime(datetime.combine(period[0].date(), self.noon)),
        )

        in_morning: bool = scu.is_period_overlapping(
            period, interval_morning, bornes=False
        )
        return in_morning

    def check_in_evening(self, period: tuple[datetime, datetime]) -> bool:
        """Vérifie si la période fait partie de l'aprèm
        (test sur la date de début)
        """

        interval_evening: tuple[datetime, datetime] = (
            scu.localize_datetime(datetime.combine(period[0].date(), self.after_noon)),
            scu.localize_datetime(datetime.combine(period[0].date(), self.evening)),
        )

        in_evening: bool = scu.is_period_overlapping(period, interval_evening)

        return in_evening

    def compute_long_assiduite(self, assi: Assiduite):
        """Calcule les métriques sur une assiduité longue (plus d'un jour)"""

        pointer_date: date = assi.date_debut.date() + timedelta(days=1)
        start_hours: timedelta = assi.date_debut - scu.localize_datetime(
            datetime.combine(assi.date_debut, self.morning)
        )
        finish_hours: timedelta = assi.date_fin - scu.localize_datetime(
            datetime.combine(assi.date_fin, self.morning)
        )

        self.add_day(assi.date_debut.date())
        self.add_day(assi.date_fin.date())

        start_period: tuple[datetime, datetime] = (
            assi.date_debut,
            scu.localize_datetime(
                datetime.combine(assi.date_debut.date(), self.evening)
            ),
        )

        finish_period: tuple[datetime, datetime] = (
            scu.localize_datetime(datetime.combine(assi.date_fin.date(), self.morning)),
            assi.date_fin,
        )
        hours = 0.0
        for period in (start_period, finish_period):
            if self.check_in_evening(period):
                self.add_half_day(period[0].date(), False)
            if self.check_in_morning(period):
                self.add_half_day(period[0].date())

        while pointer_date < assi.date_fin.date():
            # TODO : Utiliser la préférence de département : workdays
            if pointer_date.weekday() < (6 - self.skip_saturday):
                self.add_day(pointer_date)
                self.add_half_day(pointer_date)
                self.add_half_day(pointer_date, False)
                self.hours += self.hour_per_day
                hours += self.hour_per_day

            pointer_date += timedelta(days=1)

        self.hours += finish_hours.total_seconds() / 3600
        self.hours += self.hour_per_day - (start_hours.total_seconds() / 3600)

    def compute_assiduites(self, assiduites: Query or list):
        """Calcule les métriques pour la collection d'assiduité donnée"""
        assi: Assiduite
        assiduites: list[Assiduite] = (
            assiduites.all() if isinstance(assiduites, Query) else assiduites
        )
        for assi in assiduites:
            self.count += 1
            delta: timedelta = assi.date_fin - assi.date_debut

            if delta.days > 0:
                self.compute_long_assiduite(assi)

                continue

            period: tuple[datetime, datetime] = (assi.date_debut, assi.date_fin)
            deb_date: date = assi.date_debut.date()
            if self.check_in_morning(period):
                self.add_half_day(deb_date)
            if self.check_in_evening(period):
                self.add_half_day(deb_date, False)

            self.add_day(deb_date)

            self.hours += delta.total_seconds() / 3600

    def to_dict(self) -> dict[str, int or float]:
        """Retourne les métriques sous la forme d'un dictionnaire"""
        return {
            "compte": self.count,
            "journee": len(self.days),
            "demi": len(self.half_days),
            "heure": round(self.hours, 2),
        }


def get_assiduites_stats(
    assiduites: Query, metric: str = "all", filtered: dict[str, object] = None
) -> dict[str, int or float]:
    """Compte les assiduités en fonction des filtres"""

    if filtered is not None:
        deb, fin = None, None
        for key in filtered:
            match key:
                case "etat":
                    assiduites = filter_assiduites_by_etat(assiduites, filtered[key])
                case "date_fin":
                    fin = filtered[key]
                case "date_debut":
                    deb = filtered[key]
                case "moduleimpl_id":
                    assiduites = filter_by_module_impl(assiduites, filtered[key])
                case "formsemestre":
                    assiduites = filter_by_formsemestre(
                        assiduites, Assiduite, filtered[key]
                    )
                case "est_just":
                    assiduites = filter_assiduites_by_est_just(
                        assiduites, filtered[key]
                    )
                case "user_id":
                    assiduites = filter_by_user_id(assiduites, filtered[key])

        if (deb, fin) != (None, None):
            assiduites = filter_by_date(assiduites, Assiduite, deb, fin)

    metrics: list[str] = metric.split(",")
    output: dict = {}
    calculator: CountCalculator = CountCalculator()

    if filtered is None or "split" not in filtered:
        calculator.compute_assiduites(assiduites)
        count: dict = calculator.to_dict()

        for key, val in count.items():
            if key in metrics:
                output[key] = val
        return output if output else count

    etats: list[str] = (
        filtered["etat"].split(",")
        if "etat" in filtered
        else ["absent", "present", "retard"]
    )

    for etat in etats:
        output[etat] = _count_assiduites_etat(etat, assiduites, calculator, metrics)
        if "est_just" not in filtered:
            output[etat]["justifie"] = _count_assiduites_etat(
                etat, assiduites, calculator, metrics, justifie=True
            )

    return output


def _count_assiduites_etat(
    etat: str,
    assiduites: Query,
    calculator: CountCalculator,
    metrics: list[str],
    justifie: bool = False,
):
    calculator.reset()
    etat_num: int = scu.EtatAssiduite.get(etat, -1)
    assiduites_etat: Query = assiduites.filter(Assiduite.etat == etat_num)
    if justifie:
        assiduites_etat = assiduites_etat.filter(Assiduite.est_just == True)

    calculator.compute_assiduites(assiduites_etat)
    count_etat: dict = calculator.to_dict()
    output_etat: dict = {}
    for key, val in count_etat.items():
        if key in metrics:
            output_etat[key] = val
    return output_etat if output_etat else count_etat


def filter_assiduites_by_etat(assiduites: Assiduite, etat: str) -> Query:
    """
    Filtrage d'une collection d'assiduites en fonction de leur état
    """
    etats: list[str] = list(etat.split(","))
    etats = [scu.EtatAssiduite.get(e, -1) for e in etats]
    return assiduites.filter(Assiduite.etat.in_(etats))


def filter_assiduites_by_est_just(assiduites: Assiduite, est_just: bool) -> Query:
    """
    Filtrage d'une collection d'assiduites en fonction de s'ils sont justifiés
    """
    return assiduites.filter(Assiduite.est_just == est_just)


def filter_by_user_id(
    collection: Assiduite or Justificatif,
    user_id: int,
) -> Query:
    """
    Filtrage d'une collection en fonction de l'user_id
    """
    return collection.filter_by(user_id=user_id)


def filter_by_date(
    collection: Assiduite or Justificatif,
    collection_cls: Assiduite or Justificatif,
    date_deb: datetime = None,
    date_fin: datetime = None,
    strict: bool = False,
) -> Query:
    """
    Filtrage d'une collection d'assiduites en fonction d'une date
    """
    if date_deb is None:
        date_deb = datetime.min
    if date_fin is None:
        date_fin = datetime.max

    date_deb = scu.localize_datetime(date_deb)
    date_fin = scu.localize_datetime(date_fin)
    if not strict:
        return collection.filter(
            collection_cls.date_debut <= date_fin, collection_cls.date_fin >= date_deb
        )
    return collection.filter(
        collection_cls.date_debut < date_fin, collection_cls.date_fin > date_deb
    )


def filter_justificatifs_by_etat(justificatifs: Justificatif, etat: str) -> Query:
    """
    Filtrage d'une collection de justificatifs en fonction de leur état
    """
    etats: list[str] = list(etat.split(","))
    etats = [scu.EtatJustificatif.get(e, -1) for e in etats]
    return justificatifs.filter(Justificatif.etat.in_(etats))


def filter_by_module_impl(assiduites: Assiduite, module_impl_id: int or None) -> Query:
    """
    Filtrage d'une collection d'assiduites en fonction de l'ID du module_impl
    """
    return assiduites.filter(Assiduite.moduleimpl_id == module_impl_id)


def filter_by_formsemestre(
    collection_query: Assiduite or Justificatif,
    collection_class: Assiduite or Justificatif,
    formsemestre: FormSemestre,
) -> Query:
    """
    Filtrage d'une collection en fonction d'un formsemestre
    """

    if formsemestre is None:
        return collection_query.filter(False)

    collection_result = (
        collection_query.join(Identite, collection_class.etudid == Identite.id)
        .join(
            FormSemestreInscription,
            Identite.id == FormSemestreInscription.etudid,
        )
        .filter(FormSemestreInscription.formsemestre_id == formsemestre.id)
    )

    form_date_debut = formsemestre.date_debut + timedelta(days=1)
    form_date_fin = formsemestre.date_fin + timedelta(days=1)

    collection_result = collection_result.filter(
        collection_class.date_debut >= form_date_debut
    )

    return collection_result.filter(collection_class.date_fin <= form_date_fin)


def justifies(justi: Justificatif, obj: bool = False) -> list[int] or Query:
    """
    Retourne la liste des assiduite_id qui sont justifié par la justification
    Une assiduité est justifiée si elle est COMPLETEMENT ou PARTIELLEMENT
    comprise dans la plage du justificatif
    et que l'état du justificatif est "valide".
    Renvoie des id si obj == False, sinon les Assiduités
    """

    if justi.etat != scu.EtatJustificatif.VALIDE:
        return []

    assiduites_query: Assiduite = Assiduite.query.filter_by(etudid=justi.etudid)
    assiduites_query = assiduites_query.filter(
        Assiduite.date_debut >= justi.date_debut, Assiduite.date_fin <= justi.date_fin
    )

    if not obj:
        return [assi.id for assi in assiduites_query.all()]

    return assiduites_query


def get_all_justified(
    etudid: int,
    date_deb: datetime = None,
    date_fin: datetime = None,
    moduleimpl_id: int = None,
) -> Query:
    """Retourne toutes les assiduités justifiées sur une période"""

    if date_deb is None:
        date_deb = datetime.min
    if date_fin is None:
        date_fin = datetime.max

    date_deb = scu.localize_datetime(date_deb)
    date_fin = scu.localize_datetime(date_fin)
    justified: Query = Assiduite.query.filter_by(est_just=True, etudid=etudid)
    if moduleimpl_id is not None:
        justified = justified.filter_by(moduleimpl_id=moduleimpl_id)
    after = filter_by_date(
        justified,
        Assiduite,
        date_deb,
        date_fin,
    )
    return after


def create_absence(
    date_debut: datetime,
    date_fin: datetime,
    etudid: int,
    description: str = None,
    est_just: bool = False,
) -> int:
    etud: Identite = Identite.query.filter_by(etudid=etudid).first_or_404()
    assiduite_unique: Assiduite = Assiduite.create_assiduite(
        etud=etud,
        date_debut=date_debut,
        date_fin=date_fin,
        etat=scu.EtatAssiduite.ABSENT,
        description=description,
    )
    db.session.add(assiduite_unique)

    db.session.commit()
    if est_just:
        justi = Justificatif.create_justificatif(
            etud=etud,
            date_debut=date_debut,
            date_fin=date_fin,
            etat=scu.EtatJustificatif.VALIDE,
            raison=description,
        )
        db.session.add(justi)
        db.session.commit()

        compute_assiduites_justified(etud.id, [justi])

    calculator: CountCalculator = CountCalculator()
    calculator.compute_assiduites([assiduite_unique])
    return calculator.to_dict()["demi"]


# Gestion du cache
def get_assiduites_count(etudid: int, sem: dict) -> tuple[int, int]:
    """Les comptes d'absences de cet étudiant dans ce semestre:
    tuple (nb abs non justifiées, nb abs justifiées)
    Utilise un cache.
    """
    metrique = sco_preferences.get_preference("assi_metrique", sem["formsemestre_id"])
    return get_assiduites_count_in_interval(
        etudid,
        sem["date_debut_iso"],
        sem["date_fin_iso"],
        scu.translate_assiduites_metric(metrique),
    )


def formsemestre_get_assiduites_count(
    etudid: int, formsemestre: FormSemestre, moduleimpl_id: int = None
) -> tuple[int, int]:
    """Les comptes d'absences de cet étudiant dans ce semestre:
    tuple (nb abs non justifiées, nb abs justifiées)
    Utilise un cache.
    """
    metrique = sco_preferences.get_preference("assi_metrique", formsemestre.id)
    return get_assiduites_count_in_interval(
        etudid,
        date_debut=scu.localize_datetime(
            datetime.combine(formsemestre.date_debut, time(8, 0))
        ),
        date_fin=scu.localize_datetime(
            datetime.combine(formsemestre.date_fin, time(18, 0))
        ),
        metrique=scu.translate_assiduites_metric(metrique),
        moduleimpl_id=moduleimpl_id,
    )


def get_assiduites_count_in_interval(
    etudid,
    date_debut_iso: str = "",
    date_fin_iso: str = "",
    metrique="demi",
    date_debut: datetime = None,
    date_fin: datetime = None,
    moduleimpl_id: int = None,
):
    """Les comptes d'absences de cet étudiant entre ces deux dates, incluses:
    tuple (nb abs, nb abs justifiées)
    On peut spécifier les dates comme datetime ou iso.
    Utilise un cache.
    """
    date_debut_iso = date_debut_iso or date_debut.isoformat()
    date_fin_iso = date_fin_iso or date_fin.isoformat()
    key = f"{etudid}_{date_debut_iso}_{date_fin_iso}{metrique}_assiduites"

    r = sco_cache.AbsSemEtudCache.get(key)
    if not r or moduleimpl_id is not None:
        date_debut: datetime = date_debut or datetime.fromisoformat(date_debut_iso)
        date_fin: datetime = date_fin or datetime.fromisoformat(date_fin_iso)

        assiduites: Query = Assiduite.query.filter_by(etudid=etudid)
        assiduites = assiduites.filter(Assiduite.etat == scu.EtatAssiduite.ABSENT)
        justificatifs: Justificatif = Justificatif.query.filter_by(etudid=etudid)

        assiduites = filter_by_date(assiduites, Assiduite, date_debut, date_fin)

        if moduleimpl_id is not None:
            assiduites = assiduites.filter_by(moduleimpl_id=moduleimpl_id)

        justificatifs = filter_by_date(
            justificatifs, Justificatif, date_debut, date_fin
        )
        calculator: CountCalculator = CountCalculator()
        calculator.compute_assiduites(assiduites)
        nb_abs: dict = calculator.to_dict()[metrique]

        abs_just: list[Assiduite] = get_all_justified(
            etudid, date_debut, date_fin, moduleimpl_id
        )

        calculator.reset()
        calculator.compute_assiduites(abs_just)
        nb_abs_just: dict = calculator.to_dict()[metrique]

        r = (nb_abs, nb_abs_just)
        if moduleimpl_id is None:
            ans = sco_cache.AbsSemEtudCache.set(key, r)
            if not ans:
                log("warning: get_assiduites_count failed to cache")
    return r


def invalidate_assiduites_count(etudid, sem):
    """Invalidate (clear) cached counts"""
    date_debut = sem["date_debut_iso"]
    date_fin = sem["date_fin_iso"]
    for met in scu.AssiduitesMetrics.TAG:
        key = str(etudid) + "_" + date_debut + "_" + date_fin + f"{met}_assiduites"
        sco_cache.AbsSemEtudCache.delete(key)


def invalidate_assiduites_count_sem(sem):
    """Invalidate (clear) cached abs counts for all the students of this semestre"""
    inscriptions = (
        sco_formsemestre_inscriptions.do_formsemestre_inscription_listinscrits(
            sem["formsemestre_id"]
        )
    )
    for ins in inscriptions:
        invalidate_assiduites_count(ins["etudid"], sem)


def invalidate_assiduites_etud_date(etudid, date: datetime):
    """Doit etre appelé à chaque modification des assiduites
    pour cet étudiant et cette date.
    Invalide cache absence et caches semestre
    """
    from app.scodoc import sco_compute_moy

    # Semestres a cette date:
    etud = sco_etud.get_etud_info(etudid=etudid, filled=True)
    if len(etud) == 0:
        return
    else:
        etud = etud[0]
    sems = [
        sem
        for sem in etud["sems"]
        if scu.is_iso_formated(sem["date_debut_iso"], True).replace(tzinfo=UTC)
        <= date.replace(tzinfo=UTC)
        and scu.is_iso_formated(sem["date_fin_iso"], True).replace(tzinfo=UTC)
        >= date.replace(tzinfo=UTC)
    ]

    # Invalide les PDF et les absences:
    for sem in sems:
        # Inval cache bulletin et/ou note_table
        if sco_compute_moy.formsemestre_expressions_use_abscounts(
            sem["formsemestre_id"]
        ):
            # certaines formules utilisent les absences
            pdfonly = False
        else:
            # efface toujours le PDF car il affiche en général les absences
            pdfonly = True

        sco_cache.invalidate_formsemestre(
            formsemestre_id=sem["formsemestre_id"], pdfonly=pdfonly
        )

        # Inval cache compteurs absences:
        invalidate_assiduites_count(etudid, sem)


def simple_invalidate_cache(obj: dict, etudid: str or int = None):
    """Invalide le cache de l'étudiant et du / des semestres"""
    date_debut = (
        obj["date_debut"]
        if isinstance(obj["date_debut"], datetime)
        else scu.is_iso_formated(obj["date_debut"], True)
    )
    date_fin = (
        obj["date_fin"]
        if isinstance(obj["date_fin"], datetime)
        else scu.is_iso_formated(obj["date_fin"], True)
    )
    etudid = etudid if etudid is not None else obj["etudid"]
    invalidate_assiduites_etud_date(etudid, date_debut)
    invalidate_assiduites_etud_date(etudid, date_fin)