# -*- mode: python -*-
# -*- coding: utf-8 -*-

import datetime
import traceback

import psycopg2
import psycopg2.pool
import psycopg2.extras

from flask import g

import app.scodoc.sco_utils as scu
from app.scodoc.notes_log import log
from app.scodoc.sco_exceptions import ScoException, ScoValueError, NoteProcessError

from html import escape  # Python 3: cgi.escape and types.StringType are gone

quote_html = escape


def quote_dict(d):
    "html quote all values in dict"
    for k in d.keys():
        v = d[k]
        if isinstance(v, str):
            d[k] = quote_html(v, quote=True)


def unquote(s):
    "inverse of quote"
    # there is no real inverse of html quoting here:
    # only handles &amp;
    # XX see also sco_utils.unescape_html
    return s.replace("&amp;", "&")


def open_dept_connection():
    """Open a connection to the current dept db"""
    # log("open_dept_connection to " + scu.get_db_cnx_string())  # XXX
    return psycopg2.connect(scu.get_db_cnx_string())


def close_dept_connection():
    """Commit and close dept db."""
    # log("close_dept_connection to " + scu.get_db_cnx_string())  # XXX
    g.db_conn.commit()
    g.db_conn.close()


# Much simpler approach for Flask:
def GetDBConnexion(autocommit=True):  # autocommit is not used anymore
    return g.db_conn


# Note: we could now use psycopg2.extras.DictCursor instead
class ScoDocCursor(psycopg2.extensions.cursor):
    """A database cursor emulating some methods of psycopg v1 cursors"""

    def dictfetchall(self):
        col_names = [d[0] for d in self.description]
        return [dict(zip(col_names, row)) for row in self.fetchall()]

    def dictfetchone(self):
        col_names = [d[0] for d in self.description]
        row = self.fetchone()
        if row:
            return dict(zip(col_names, row))
        else:
            return {}
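

# Editor's illustrative sketch (not part of the original module): the note above
# suggests psycopg2.extras.DictCursor could replace ScoDocCursor.  A minimal
# equivalent of dictfetchall() with the stock extras cursor might look like this;
# the query string is a placeholder and GetDBConnexion() assumes an active Flask
# request context (g.db_conn is set).
def _example_dictfetchall_with_extras(query="select 1 as x"):
    cnx = GetDBConnexion()
    cursor = cnx.cursor(cursor_factory=psycopg2.extras.DictCursor)
    cursor.execute(query)
    # DictCursor rows support both index and key access; dict() gives plain dicts
    return [dict(row) for row in cursor.fetchall()]

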
def SimpleQuery(context, query, args, cursor=None):
    if not cursor:
        cnx = GetDBConnexion()
        cursor = cnx.cursor(cursor_factory=ScoDocCursor)
    # log('SimpleQuery(%s)' % (query % args))
    cursor.execute(query, args)
    return cursor


def SimpleDictFetch(context, query, args, cursor=None):
    cursor = SimpleQuery(context, query, args, cursor=cursor)
    return cursor.dictfetchall()
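

# Editor's illustrative sketch: how SimpleDictFetch is typically called, with a
# parameterized query.  The table and column names here are only for
# illustration, not a statement about the real schema; `context` is the unused
# legacy first argument.
def _example_simple_dict_fetch(etudid):
    return SimpleDictFetch(
        None,  # context (unused)
        "SELECT * FROM identite WHERE etudid = %(etudid)s",
        {"etudid": etudid},
    )

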
def DBInsertDict(cnx, table, vals, commit=0, convert_empty_to_nulls=1):
    "insert into table values in dict 'vals'"
    cursor = cnx.cursor(cursor_factory=ScoDocCursor)
    if convert_empty_to_nulls:
        for col in vals.keys():
            if vals[col] == "":
                vals[col] = None
    # open('/tmp/vals','a').write( str(vals) + '\n' )
    cols = list(vals.keys())
    colnames = ",".join(cols)
    fmt = ",".join(["%%(%s)s" % col for col in cols])
    # print 'insert into %s (%s) values (%s)' % (table,colnames,fmt)
    oid = None
    try:
        if vals:
            cursor.execute(
                "insert into %s (%s) values (%s)" % (table, colnames, fmt), vals
            )
        else:
            cursor.execute("insert into %s default values" % table)
        oid = cursor.lastrowid
    except:
        log("DBInsertDict: EXCEPTION !")
        log("DBInsertDict: table=%s, vals=%s" % (str(table), str(vals)))
        log("DBInsertDict: commit (exception)")
        cnx.commit()  # get rid of this transaction
        raise  # and re-raise exception
    if commit:
        log("DBInsertDict: commit (requested)")
        cnx.commit()
    return oid
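

# Editor's illustrative sketch: inserting a row from a dict.  The table and
# column names are hypothetical; empty strings are converted to NULL by
# default, and commit=1 commits immediately on the given connection.
def _example_insert(cnx):
    return DBInsertDict(
        cnx,
        "absences",                       # hypothetical table name
        {"etudid": "E123", "motif": ""},  # "" becomes NULL
        commit=1,
    )

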
def DBSelectArgs(
    cnx,
    table,
    vals,
    what=["*"],
    sortkey=None,
    test="=",
    operator="and",
    distinct=True,
    aux_tables=[],
    id_name=None,
    limit="",
    offset="",
):
    """Select * from table where values match dict vals.
    Returns a list of dicts.
    aux_tables = [ (tablename, id_name), ... ]
    """
    cursor = cnx.cursor(cursor_factory=ScoDocCursor)
    if sortkey:
        orderby = " order by " + sortkey
    else:
        orderby = ""
    if distinct:
        distinct = " distinct "
    else:
        distinct = ""
    if limit != "":
        limit = " LIMIT %d" % limit
    if not offset:
        offset = ""
    if offset != "":
        offset = " OFFSET %d" % offset
    operator = " " + operator + " "
    # list of tables (after "from")
    tables = [table] + [x[0] for x in aux_tables]
    for i in range(len(tables)):
        tables[i] = "%s T%d" % (tables[i], i)
    tables = ", ".join(tables)
    # condition (after "where")
    cond = ""
    i = 1
    cl = []
    for (_, aux_id) in aux_tables:
        cl.append("T0.%s = T%d.%s" % (id_name, i, aux_id))
        i = i + 1
    cond += " and ".join(cl)
    # Regular expression handling:
    # do not allow expressions
    if test == "~":
        for k in vals.keys():
            vals[k] = vals[k].translate(str.maketrans("", "", '%*()+=&|[]"`'))

    if vals:
        if aux_tables:  # paren
            cond += " AND ( "
        cond += operator.join(
            ["T0.%s%s%%(%s)s" % (x, test, x) for x in vals.keys() if vals[x] is not None]
        )
        cnuls = " and ".join(
            ["T0.%s is NULL" % x for x in vals.keys() if vals[x] is None]
        )
        if cnuls:
            if cond:
                cond = cond + " and " + cnuls
            else:
                cond = cnuls
        # close paren
        if aux_tables:
            cond += ") "
    if cond:
        cond = " where " + cond
    #
    req = (
        "select "
        + distinct
        + ", ".join(what)
        + " from "
        + tables
        + cond
        + orderby
        + limit
        + offset
    )
    # open('/tmp/select.log','a').write( req % vals + '\n' )
    try:
        cursor.execute(req, vals)
    except:
        log('Exception in DBSelectArgs:\n\treq="%s"\n\tvals="%s"\n' % (req, vals))
        log(traceback.format_exc())
        raise ScoException()
    return cursor.dictfetchall()
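

# Editor's illustrative sketch: a typical DBSelectArgs call.  Table, column and
# sort key names are hypothetical; this call builds and runs roughly
# "select distinct * from etudiants T0 where T0.nom~%(nom)s order by nom LIMIT 10".
def _example_select(cnx, nom):
    return DBSelectArgs(
        cnx,
        "etudiants",       # hypothetical table
        {"nom": nom},      # filter values (dict: column -> value)
        sortkey="nom",
        test="~",          # PostgreSQL regexp match (special chars are stripped)
        limit=10,
    )

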
def DBUpdateArgs(cnx, table, vals, where=None, commit=False, convert_empty_to_nulls=1):
    if not vals or where is None:
        return
    cursor = cnx.cursor(cursor_factory=ScoDocCursor)
    if convert_empty_to_nulls:
        for col in vals.keys():
            if vals[col] == "":
                vals[col] = None
    s = ", ".join(["%s=%%(%s)s" % (x, x) for x in vals.keys()])
    try:
        req = "update " + table + " set " + s + " where " + where
        cursor.execute(req, vals)
        # log('req=%s\n' % req)
        # log('vals=%s\n' % vals)
    except:
        cnx.commit()  # get rid of this transaction
        log('Exception in DBUpdateArgs:\n\treq="%s"\n\tvals="%s"\n' % (req, vals))
        raise  # and re-raise exception
    if commit:
        cnx.commit()


def DBDelete(cnx, table, colid, val, commit=False):
    cursor = cnx.cursor(cursor_factory=ScoDocCursor)
    try:
        cursor.execute(
            "delete from " + table + " where %s=%%(%s)s" % (colid, colid), {colid: val}
        )
    except:
        cnx.commit()  # get rid of this transaction
        raise  # and re-raise exception
    if commit:
        cnx.commit()
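

# Editor's illustrative sketch: updating then deleting a row by primary key.
# The table name, id column and values are hypothetical examples; the `where`
# clause uses the same %(name)s placeholders as the vals dict.
def _example_update_and_delete(cnx, absence_id):
    DBUpdateArgs(
        cnx,
        "absences",  # hypothetical table
        {"absence_id": absence_id, "motif": "maladie"},
        where="absence_id=%(absence_id)s",
        commit=True,
    )
    DBDelete(cnx, "absences", "absence_id", absence_id, commit=True)

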
# --------------------------------------------------------------------

# REQLOG = open('/tmp/req.log', 'w')  # DEBUG
# REQN = 0


class EditableTable(object):
    """--- generic class: SQL table with create/edit/list/delete"""

    def __init__(
        self,
        table_name,
        id_name,
        dbfields,
        sortkey=None,
        output_formators={},
        input_formators={},
        aux_tables=[],
        convert_null_outputs_to_empty=True,
        callback_on_write=None,
        allow_set_id=False,
        html_quote=True,
        fields_creators={},  # { field : [ sql_command_to_create_it ] }
        filter_nulls=True,  # don't allow setting fields to null
    ):
        self.table_name = table_name
        self.id_name = id_name
        self.aux_tables = aux_tables
        self.dbfields = dbfields
        self.sortkey = sortkey
        self.output_formators = output_formators
        self.input_formators = input_formators
        self.convert_null_outputs_to_empty = convert_null_outputs_to_empty
        self.callback_on_write = (
            callback_on_write  # called after each modification (USELESS and unused)
        )
        self.allow_set_id = allow_set_id
        self.html_quote = html_quote
        self.fields_creators = fields_creators
        self.filter_nulls = filter_nulls
        self.sql_default_values = None

    def create(self, cnx, args, has_uniq_values=False):
        "create object in table"
        vals = dictfilter(args, self.dbfields, self.filter_nulls)
        if self.id_name in vals and not self.allow_set_id:
            del vals[self.id_name]
        if self.html_quote:
            quote_dict(vals)  # quote all HTML markup
        # format values
        for title in vals.keys():
            if title in self.input_formators:
                vals[title] = self.input_formators[title](vals[title])
        # insert
        oid = DBInsertDict(cnx, self.table_name, vals, commit=True)
        # get back new object id
        cursor = cnx.cursor(cursor_factory=ScoDocCursor)
        cursor.execute(
            "select %(id_name)s from %(table_name)s where oid=%(oid)s"
            % {"id_name": self.id_name, "table_name": self.table_name, "oid": oid}
        )
        new_id = cursor.fetchone()[0]
        if has_uniq_values:  # XXX probably obsolete
            # check all tuples (without id_name) are UNIQUE !
            res = DBSelectArgs(cnx, self.table_name, vals, what=[self.id_name])
            if len(res) != 1:
                # BUG !
                log("create: BUG table_name=%s args=%s" % (self.table_name, str(args)))
                assert len(res) == 1, "len(res) = %d != 1 !" % len(res)
        if self.callback_on_write:
            self.callback_on_write()

        return new_id

    def delete(self, cnx, oid, commit=True):
        "delete tuple"
        DBDelete(cnx, self.table_name, self.id_name, oid, commit=commit)
        if self.callback_on_write:
            self.callback_on_write()

    def list(
        self,
        cnx,
        args={},
        operator="and",
        test="=",
        sortkey=None,
        disable_formatting=False,
        limit="",
        offset="",
    ):
        "returns list of dicts"
        # REQLOG.write('%s: %s by %s (%s) %d\n' % (self.table_name, args, sys._getframe(1).f_code.co_name, sys._getframe(2).f_code.co_name, REQN))
        # REQLOG.flush()
        # global REQN
        # REQN = REQN + 1
        vals = dictfilter(args, self.dbfields, self.filter_nulls)
        if not sortkey:
            sortkey = self.sortkey
        res = DBSelectArgs(
            cnx,
            self.table_name,
            vals,
            sortkey=sortkey,
            test=test,
            operator=operator,
            aux_tables=self.aux_tables,
            id_name=self.id_name,
            limit=limit,
            offset=offset,
        )
        for r in res:
            self.format_output(r, disable_formatting=disable_formatting)
        return res

    def format_output(self, r, disable_formatting=False):
        "Format dict using provided output_formators"
        for (k, v) in r.items():
            if v is None and self.convert_null_outputs_to_empty:
                v = ""
            # format value
            if not disable_formatting and k in self.output_formators:
                try:  # XXX debug "isodate"
                    v = self.output_formators[k](v)
                except:
                    log("*** list: vars=%s" % str(vars()))
                    log("*** list: r=%s" % str(r))
                    raise
            r[k] = v

    def edit(self, cnx, args, html_quote=None):
        """Change fields"""
        vals = dictfilter(args, self.dbfields, self.filter_nulls)
        html_quote = html_quote or self.html_quote
        if html_quote:
            quote_dict(vals)  # quote HTML
        # format values
        for title in vals.keys():
            if title in self.input_formators:
                try:
                    vals[title] = self.input_formators[title](vals[title])
                except:
                    log("exception while converting %s=%s" % (title, vals[title]))
                    raise
        DBUpdateArgs(
            cnx,
            self.table_name,
            vals,
            where="%s=%%(%s)s" % (self.id_name, self.id_name),
            commit=True,
        )
        if self.callback_on_write:
            self.callback_on_write()

    def get_sql_default_values(self, cnx):
        "return dict with SQL default values for each field"
        if self.sql_default_values is None:  # not cached
            # We could insert a new tuple, read its values and delete it,
            # but that does not work for tables which cannot create tuples
            # with default values, because of references or integrity constraints:
            # oid = self.create(cnx, {})
            # vals = self.list(cnx, args={self.id_name: oid})[0]
            # self.delete(cnx, oid)
            # self.sql_default_values = vals
            #
            # PostgreSQL-specific method (>= 7.4)
            cursor = cnx.cursor(cursor_factory=ScoDocCursor)
            cursor.execute(
                "SELECT column_name, data_type, column_default FROM information_schema.columns WHERE table_name = '%s'"
                % self.table_name
            )
            d = {}
            for info in cursor.dictfetchall():
                v = info["column_default"]
                # strip type information if present (eg 'hello'::text)
                if v:
                    v = v.split("::")[0]
                # convert to a Python value
                if v:
                    if info["data_type"] == "text":
                        log('get: v="%s"' % v)
                        if v[0] == v[-1] == "'":
                            v = v[1:-1]  # strip quotes
                        elif v[:2] == "E'" and v[-1] == "'":
                            v = v[2:-1]  # postgresql string with escapes
                            # fix (unclear why values are quoted here;
                            # this is not the case in ordinary tables)
                            v = v.replace("\\012", "\n")
                            v = v.replace("''", "'")  # same
                        log('--> v="%s"' % v)
                    elif info["data_type"] == "real":
                        v = float(v)
                    elif info["data_type"] == "integer":
                        v = int(v)
                    # elif info['data_type'] == 'date':
                    #     pass  # XXX todo
                    else:
                        log("Warning: unhandled SQL type in get_sql_default_values")
                d[info["column_name"]] = v
            self.sql_default_values = d
        return self.sql_default_values
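

# Editor's illustrative sketch: how an EditableTable is typically declared and
# used.  The "news" table, its columns and the chosen formators are hypothetical
# examples, not the real ScoDoc definitions; cnx is assumed to come from
# GetDBConnexion() inside a Flask request context.
def _example_editable_table_usage(cnx):
    news_editor = EditableTable(
        "news",
        "news_id",
        ("news_id", "title", "date", "visible"),
        sortkey="date desc",
        output_formators={"date": DateISOtoDMY},
        input_formators={"date": DateDMYtoISO},
    )
    news_id = news_editor.create(cnx, {"title": "hello", "date": "01/09/2021"})
    items = news_editor.list(cnx, {"news_id": news_id})
    news_editor.edit(cnx, {"news_id": news_id, "title": "updated title"})
    news_editor.delete(cnx, news_id)
    return items

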
def dictfilter(d, fields, filter_nulls=True):
    # returns a copy of d with only the keys listed in "fields" and non-null values
    r = {}
    for f in fields:
        if f in d and (d[f] is not None or not filter_nulls):
            try:
                val = d[f].strip()
            except:
                val = d[f]
            # if val != '': not a good idea: how to suppress a field ?
            r[f] = val
    return r
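

# Editor's illustrative sketch: dictfilter() keeps only the listed fields,
# drops None values (unless filter_nulls=False) and strips strings.
# The field names below are arbitrary examples.
def _example_dictfilter():
    args = {"nom": "  Dupont ", "prenom": None, "extra": "ignored"}
    return dictfilter(args, ("nom", "prenom"))  # -> {"nom": "Dupont"}

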
# --------------------------------------------------------------------
# --- Misc Tools


def DateDMYtoISO(dmy, null_is_empty=False):
    "convert date string from french format to ISO"
    if not dmy:
        if null_is_empty:
            return ""
        else:
            return None
    if not isinstance(dmy, str):
        return dmy.strftime("%Y-%m-%d")

    t = dmy.split("/")

    if len(t) != 3:
        raise ScoValueError('Format de date (j/m/a) invalide: "%s"' % str(dmy))
    day, month, year = t
    year = int(year)
    month = int(month)
    day = int(day)
    # accept years YYYY or YY, using 1970 as pivot
    if year < 100:
        if year > 70:
            year += 1900
        else:
            year += 2000

    if month < 1 or month > 12:
        raise ScoValueError("mois de la date invalide ! (%s)" % month)
    # compute the number of days in the month:
    mo = month
    if mo > 7:
        mo = mo + 1
    if mo % 2:
        MonthNbDays = 31
    elif mo == 2:
        if year % 4 == 0 and ((year % 100 != 0) or (year % 400 == 0)):
            MonthNbDays = 29  # leap
        else:
            MonthNbDays = 28
    else:
        MonthNbDays = 30
    if day < 1 or day > MonthNbDays:
        raise ScoValueError("jour de la date invalide ! (%s)" % day)
    return "%04d-%02d-%02d" % (year, month, day)


def DateISOtoDMY(isodate):
    if not isodate:
        return ""
    arg = isodate  # debug
    # in case isodate is a DateTime instance!
    try:
        isodate = "%s-%s-%s" % (isodate.year(), isodate.month(), isodate.day())
        # log('DateISOtoDMY: converted isodate to iso !')
    except:
        pass
    # drop time from isodate and split
    t = str(isodate).split()[0].split("-")
    if len(t) != 3:
        # XXX tracking a rather strange intermittent bug
        log('*** DateISOtoDMY: invalid isodate "%s" (arg="%s")' % (str(isodate), arg))
        raise NoteProcessError(
            'invalid isodate: "%s" (arg="%s" type=%s)' % (str(isodate), arg, type(arg))
        )
    year, month, day = t
    year = int(year)
    month = int(month)
    day = int(day)
    if month < 1 or month > 12:
        raise ValueError("invalid month")
    if day < 1 or day > 31:
        raise ValueError("invalid day")
    return "%02d/%02d/%04d" % (day, month, year)
def TimetoISO8601(t, null_is_empty=False):
    "convert time string to ISO 8601 (allow 16:03, 16h03, 16)"
    if isinstance(t, datetime.time):
        return t.isoformat()
    if not t and null_is_empty:
        return ""
    t = t.strip().upper().replace("H", ":")
    if t and t.count(":") == 0 and len(t) < 3:
        t = t + ":00"
    return t


def TimefromISO8601(t):
    "convert time string from ISO 8601 to our display format"
    if not t:
        return t
    # XXX strange bug workaround...
    try:
        t = "%s:%s" % (t.hour(), t.minute())
        # log('TimefromISO8601: converted isotime to iso !')
    except:
        pass
    fs = str(t).split(":")
    return fs[0] + "h" + fs[1]  # discard seconds


def TimeDuration(heure_debut, heure_fin):
    """duration (integer number of minutes) between two times in our format,
    e.g. 12h23
    """
    if heure_debut and heure_fin:
        h0, m0 = [int(x) for x in heure_debut.split("h")]
        h1, m1 = [int(x) for x in heure_fin.split("h")]
        d = (h1 - h0) * 60 + (m1 - m0)
        return d
    else:
        return None
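

# Editor's illustrative sketch: the time helpers above in action, using the
# "16h03" display format produced by TimefromISO8601.
def _example_time_helpers():
    iso = TimetoISO8601("16h03")              # -> "16:03"
    disp = TimefromISO8601("16:03:00")        # -> "16h03"
    minutes = TimeDuration("16h03", "17h30")  # -> 87
    return iso, disp, minutes

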
def float_null_is_zero(x):
    if x is None or x == "":
        return 0.0
    else:
        return float(x)


def int_null_is_zero(x):
    if x is None or x == "":
        return 0
    else:
        return int(x)


def int_null_is_null(x):
    if x is None or x == "":
        return None
    else:
        return int(x)


def float_null_is_null(x):
    if x is None or x == "":
        return None
    else:
        return float(x)


# post filtering
#
def UniqListofDicts(L, key):
    """L is a list of dicts.
    Remove duplicates: keep only one item per distinct value of the given key.
    """
    # well, the code is simpler than the documentation:
    d = {}
    for item in L:
        d[item[key]] = item
    return list(d.values())


#
def copy_tuples_changing_attribute(
    cnx, table, column, old_value, new_value, to_exclude=[]
):
    """Duplicate tuples in DB table, replacing column old_value by new_value

    Will raise an exception on integrity constraint violation!
    """
    cursor = cnx.cursor(cursor_factory=ScoDocCursor)
    cursor.execute(
        "select * from %s where %s=%%(old_value)s" % (table, column),
        {"old_value": old_value},
    )
    res = cursor.dictfetchall()
    for t in res:
        t[column] = new_value
        for c in to_exclude:
            del t[c]
        DBInsertDict(cnx, table, t, convert_empty_to_nulls=False)
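

# Editor's illustrative sketch: duplicating all rows of a hypothetical
# "formsemestre_uecoef" table from one semester to another, dropping the
# primary key column so new ids are generated.  All names are examples only.
def _example_copy_tuples(cnx, old_formsemestre_id, new_formsemestre_id):
    copy_tuples_changing_attribute(
        cnx,
        table="formsemestre_uecoef",
        column="formsemestre_id",
        old_value=old_formsemestre_id,
        new_value=new_formsemestre_id,
        to_exclude=["formsemestre_uecoef_id"],  # let the DB assign new ids
    )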