diff --git a/app/auth/models.py b/app/auth/models.py index a293685e..073f687e 100644 --- a/app/auth/models.py +++ b/app/auth/models.py @@ -306,6 +306,13 @@ class User(UserMixin, db.Model): role, dept = UserRole.role_dept_from_string(r_d) self.add_role(role, dept) + # Set cas_id using regexp if configured: + exp = ScoDocSiteConfig.get("cas_uid_from_mail_regexp") + if exp and self.email_institutionnel: + cas_id = ScoDocSiteConfig.extract_cas_id(self.email_institutionnel) + if cas_id is not None: + self.cas_id = cas_id + def get_token(self, expires_in=3600): "Un jeton pour cet user. Stocké en base, non commité." now = datetime.utcnow() diff --git a/app/forms/main/config_cas.py b/app/forms/main/config_cas.py index f68aa3cc..7e2b73c5 100644 --- a/app/forms/main/config_cas.py +++ b/app/forms/main/config_cas.py @@ -30,8 +30,17 @@ Formulaire configuration CAS """ from flask_wtf import FlaskForm -from wtforms import BooleanField, SubmitField +from wtforms import BooleanField, SubmitField, ValidationError from wtforms.fields.simple import FileField, StringField +from wtforms.validators import Optional + +from app.models import ScoDocSiteConfig + + +def check_cas_uid_from_mail_regexp(form, field): + "Vérifie la regexp fournie pur l'extraction du CAS id" + if not ScoDocSiteConfig.cas_uid_from_mail_regexp_is_valid(field.data): + raise ValidationError("expression régulière invalide") class ConfigCASForm(FlaskForm): @@ -50,7 +59,8 @@ class ConfigCASForm(FlaskForm): ) cas_login_route = StringField( label="Route du login CAS", - description="""ajouté à l'URL du serveur: exemple /cas (si commence par /, part de la racine)""", + description="""ajouté à l'URL du serveur: exemple /cas + (si commence par /, part de la racine)""", default="/cas", ) cas_logout_route = StringField( @@ -70,6 +80,18 @@ class ConfigCASForm(FlaskForm): comptes utilisateurs.""", ) + cas_uid_from_mail_regexp = StringField( + label="Expression pour extraire l'identifiant utilisateur", + description="""regexp python appliquée au mail institutionnel de l'utilisateur, + dont le premier groupe doit donner l'identifiant CAS. + Si non fournie, le super-admin devra saisir cet identifiant pour chaque compte. + Par exemple, (.*)@ indique que le mail sans le domaine (donc toute + la partie avant le @) est l'identifiant. + Pour prendre le mail complet, utiliser (.*). + """, + validators=[Optional(), check_cas_uid_from_mail_regexp], + ) + cas_ssl_verify = BooleanField("Vérification du certificat SSL") cas_ssl_certificate_file = FileField( label="Certificat (PEM)", diff --git a/app/models/config.py b/app/models/config.py index 15da815f..60ce884b 100644 --- a/app/models/config.py +++ b/app/models/config.py @@ -5,6 +5,7 @@ import json import urllib.parse +import re from flask import flash from app import current_app, db, log @@ -103,6 +104,7 @@ class ScoDocSiteConfig(db.Model): "cas_logout_route": str, "cas_validate_route": str, "cas_attribute_id": str, + "cas_uid_from_mail_regexp": str, # Assiduité "morning_time": str, "lunch_time": str, @@ -395,6 +397,41 @@ class ScoDocSiteConfig(db.Model): data_links = json.dumps(links_dict) cls.set("personalized_links", data_links) + @classmethod + def extract_cas_id(cls, email_addr: str) -> str | None: + "Extract cas_id from maill, using regexp in config. None if not possible." + exp = cls.get("cas_uid_from_mail_regexp") + if not exp or not email_addr: + return None + try: + match = re.search(exp, email_addr) + except re.error: + log("error extracting CAS id from '{email_addr}' using regexp '{exp}'") + return None + if not match: + log("no match extracting CAS id from '{email_addr}' using regexp '{exp}'") + return None + try: + cas_id = match.group(1) + except IndexError: + log( + "no group found extracting CAS id from '{email_addr}' using regexp '{exp}'" + ) + return None + return cas_id + + @classmethod + def cas_uid_from_mail_regexp_is_valid(cls, exp: str) -> bool: + "True si l'expression régulière semble valide" + # check that it compiles + try: + pattern = re.compile(exp) + except re.error: + return False + # and returns at least one group on a simple cannonical address + match = pattern.search("emmanuel@exemple.fr") + return len(match.groups()) > 0 + @classmethod def assi_get_rounded_time(cls, label: str, default: str) -> float: "Donne l'heure stockée dans la config globale sous label, en float arrondi au quart d'heure" diff --git a/app/templates/config_cas.j2 b/app/templates/config_cas.j2 index ce45cfa6..ca500b1c 100644 --- a/app/templates/config_cas.j2 +++ b/app/templates/config_cas.j2 @@ -23,6 +23,7 @@ {{ wtf.form_field(form.cas_logout_route) }} {{ wtf.form_field(form.cas_validate_route) }} {{ wtf.form_field(form.cas_attribute_id) }} + {{ wtf.form_field(form.cas_uid_from_mail_regexp) }}