Module pugsql.statement
Compiled SQL function objects.
Source code
"""
Compiled SQL function objects.
"""
from .exceptions import InvalidArgumentError
from contextlib import contextmanager
import sqlalchemy
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import BindParameter
import threading
# Thread-local slot holding the parameters of the call currently being
# executed, so the bind-parameter compile hook can inspect them.
_locals = threading.local()


@contextmanager
def _compile_context(multiparams, params):
    """Publish the current call's parameters on the thread-local slot.

    While the ``with`` body runs, ``_locals.compile_context`` is a dict with
    the keys ``'multiparams'`` and ``'params'``. On exit (normal or via an
    exception) the value that was present before entry is put back.

    Args:
        multiparams: positional parameter sets passed to the statement.
        params: keyword parameters passed to the statement.
    """
    # Remember whatever was there so a nested use does not wipe out the
    # enclosing context when it exits.
    previous = getattr(_locals, 'compile_context', None)
    _locals.compile_context = {
        'multiparams': multiparams,
        'params': params,
    }
    try:
        yield
    finally:
        _locals.compile_context = previous
@compiles(BindParameter)
def _visit_bindparam(element, compiler, **kw):
    """Compile hook for ``BindParameter`` that enables list expansion.

    When a compile context is active and the value supplied for this
    parameter is a tuple or list, the parameter is flagged as ``expanding``
    before compilation is delegated to the compiler's own
    ``visit_bindparam``.

    Args:
        element: the bind parameter being compiled.
        compiler: the active SQL compiler.
        **kw: compiler options; forwarded unchanged to the default visitor.

    Returns:
        Whatever ``compiler.visit_bindparam`` returns for the element.
    """
    cc = getattr(_locals, 'compile_context', None)
    if cc and _is_expanding_param(element, cc):
        # NOTE: this mutates the bind parameter object itself; nothing here
        # resets the flag afterwards.
        element.expanding = True
    # Pass the compiler's keyword options through instead of dropping them.
    return compiler.visit_bindparam(element, **kw)
def _is_expanding_param(element, cc):
    """Tell whether the value supplied for ``element.key`` is a tuple or list.

    Args:
        element: object exposing a ``key`` attribute (the parameter name).
        cc: compile context dict; its ``'params'`` entry maps names to values.

    Returns:
        ``True`` if the key is present in the params and its value is a
        ``tuple`` or ``list``; ``False`` otherwise.
    """
    supplied = cc['params']
    key = element.key
    return key in supplied and isinstance(supplied[key], (tuple, list))
class Result(object):
    """Base class for objects that turn an executed result into a value.

    Both members raise ``NotImplementedError`` here; subclasses supply the
    actual behavior.
    """

    @property
    def display_type(self):
        """Short label for this result kind (not implemented on the base)."""
        raise NotImplementedError()

    def transform(self, r):
        """Convert the result object ``r`` (not implemented on the base)."""
        raise NotImplementedError()
class One(Result):
    """Result kind that yields the first row as a dict, or ``None``."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'row'

    def transform(self, r):
        """Return the first row of ``r`` keyed by ``r.keys()``.

        Returns ``None`` when ``r.first()`` gives back a falsy value.
        """
        first_row = r.first()
        if not first_row:
            return None
        return dict(zip(r.keys(), first_row))
class Many(Result):
    """Result kind that yields every row as a dict."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'rows'

    def transform(self, r):
        """Return a generator of dicts, one per row from ``r.fetchall()``.

        The keys and the row list are read immediately; only the dict
        construction is deferred until the generator is consumed.
        """
        columns = r.keys()
        fetched = r.fetchall()
        return (dict(zip(columns, row)) for row in fetched)
class Affected(Result):
    """Result kind that reports the row count of the executed statement."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'rowcount'

    def transform(self, r):
        """Return ``r.rowcount``."""
        return r.rowcount
class Scalar(Result):
    """Result kind that yields the first column of the first row."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'scalar'

    def transform(self, r):
        """Return element 0 of ``r.first()``, or ``None`` if that is falsy."""
        first_row = r.first()
        return first_row[0] if first_row else None
class Insert(Scalar):
    """Result kind for inserts: prefers ``lastrowid`` over a scalar value."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'insert'

    def transform(self, r):
        """Return ``r.lastrowid`` when ``r`` has that attribute.

        Otherwise defer to ``Scalar.transform``.
        """
        if not hasattr(r, 'lastrowid'):
            return super(Insert, self).transform(r)
        return r.lastrowid
class Raw(Result):
    """Result kind that hands the result object back untouched."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'raw'

    def transform(self, r):
        """Return ``r`` as-is."""
        return r
class Statement(object):
    """A named SQL statement that is invoked like a function.

    The SQL string is wrapped with ``sqlalchemy.sql.text`` at construction.
    Calling the instance executes it through the associated module's
    ``_execute`` and passes the outcome to ``result.transform``.
    """

    def __init__(self, name, sql, doc, result, filename=None):
        """Validate the arguments and prepare the SQL text.

        Args:
            name: statement name; must be truthy.
            sql: SQL string; must not be ``None`` or blank. Stored stripped.
            doc: documentation for the statement (stored as given).
            result: object providing ``transform`` and ``display_type``;
                must be truthy.
            filename: optional origin file, appended to error messages.

        Raises:
            ValueError: if ``name``, ``sql`` or ``result`` is missing/empty.
        """
        # Assigned before validation because _value_err reads it.
        self.filename = filename

        if not name:
            self._value_err('Statement must have a name.')
        if sql is None:
            self._value_err('Statement must have a SQL string.')
        sql = sql.strip()
        if not sql:
            self._value_err('SQL string cannot be empty.')
        if not result:
            self._value_err('Statement must have a result type.')

        self.name = name
        self.sql = sql
        self.doc = doc
        self.result = result
        self._module = None
        self._text = sqlalchemy.sql.text(self.sql)

    def _value_err(self, msg):
        """Raise ``ValueError(msg)``, adding the filename when one is set."""
        if self.filename:
            raise ValueError('%s In: %s' % (msg, self.filename))
        raise ValueError(msg)

    def set_module(self, module):
        """Associate this statement with the module used to execute it."""
        self._module = module

    def _assert_module(self):
        """Raise ``RuntimeError`` if no module has been associated yet."""
        if self._module is None:
            raise RuntimeError(
                'This statement is not associated with a module')

    def __call__(self, *multiparams, **params):
        """Execute the statement and return the transformed result.

        Args:
            *multiparams: positional parameter sets (dicts or lists).
            **params: keyword parameters bound by name.

        Returns:
            The value produced by ``self.result.transform``.

        Raises:
            RuntimeError: if no module is associated.
            InvalidArgumentError: if the arguments look like plain
                positional values rather than keyword parameters.
        """
        self._assert_module()
        multiparams, params = self._convert_params(multiparams, params)
        self._validateMultiparams(params, multiparams)
        with _compile_context(multiparams, params):
            try:
                r = self._module._execute(self._text, *multiparams, **params)
            except AttributeError as e:
                # This exact message is treated as a sign of positional
                # misuse (presumably a tuple reached code that expected a
                # mapping) and is translated into the friendlier error.
                if str(e) == "'tuple' object has no attribute 'keys'":
                    self._positionalArgError()
                raise
        return self.result.transform(r)

    def _validateMultiparams(self, params, multiparams):
        """Reject common misuse of positional arguments.

        Positional arguments may not be combined with keyword arguments, and
        each one must be a ``dict``, ``list`` or ``set`` (subclasses
        included).
        """
        # try to catch some common usage mistakes
        if not multiparams:
            return
        if params:
            # currently never supported to pass both positional args and
            # keywords
            self._positionalArgError()
        for p in multiparams:
            # isinstance (not an exact type match) so that subclasses such as
            # OrderedDict are accepted.
            # NOTE: __call__ runs _convert_params first, which turns sets
            # into tuples, so a set positional argument is rejected here.
            if not isinstance(p, (dict, list, set)):
                self._positionalArgError()

    def _positionalArgError(self):
        """Raise the standard error for positional-argument misuse."""
        raise InvalidArgumentError(
            'Pass keyword arguments to statements (received '
            'positional arguments).')

    def _convert_params(self, multiparams, params):
        """Return ``(list, dict)`` copies of the arguments with sets as tuples."""
        def conv(x):
            if isinstance(x, set):
                return tuple(x)
            return x

        return (
            [conv(p) for p in multiparams],
            {k: conv(v) for k, v in params.items()})

    def _param_names(self):
        """Return bind parameter names ordered by first appearance in the SQL."""
        import re

        def kfn(p):
            # Match the whole ``:name`` token so that ``:a`` is not found
            # inside ``:a_b``; names that cannot be located sort last.
            found = re.search(
                r'(?<![:\w\\]):' + re.escape(p) + r'(?!\w)', self.sql)
            return found.start() if found else len(self.sql)

        return sorted(self._text._bindparams.keys(), key=kfn)

    def __str__(self):
        """Render as ``pugsql.statement.Statement: name(p=None, ...) :: type``."""
        paramstr = ', '.join(['%s=None' % k for k in self._param_names()])
        return 'pugsql.statement.Statement: %s(%s) :: %s' % (
            self.name, paramstr, self.result.display_type)

    def __repr__(self):
        """Same text as ``__str__``."""
        return str(self)
Classes
class Affected
-
Source code
class Affected(Result):
    """Result kind that reports the row count of the executed statement."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'rowcount'

    def transform(self, r):
        """Return ``r.rowcount``."""
        return r.rowcount
Ancestors
Instance variables
var display_type
-
Source code
@property
def display_type(self):
    """Label for the ``Affected`` result kind."""
    return 'rowcount'
Methods
def transform(self, r)
-
Source code
def transform(self, r):
    """Return ``r.rowcount``."""
    count = r.rowcount
    return count
class Insert
-
Source code
class Insert(Scalar):
    """Result kind for inserts: prefers ``lastrowid`` over a scalar value."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'insert'

    def transform(self, r):
        """Return ``r.lastrowid`` if present, else defer to ``Scalar``."""
        if not hasattr(r, 'lastrowid'):
            return super(Insert, self).transform(r)
        return r.lastrowid
Ancestors
Instance variables
var display_type
-
Source code
@property
def display_type(self):
    """Label for the ``Insert`` result kind."""
    return 'insert'
Methods
def transform(self, r)
-
Source code
def transform(self, r):
    """Return ``r.lastrowid`` if present, else defer to ``Scalar.transform``."""
    if not hasattr(r, 'lastrowid'):
        return super(Insert, self).transform(r)
    return r.lastrowid
class Many
-
Source code
class Many(Result):
    """Result kind that yields every row as a dict."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'rows'

    def transform(self, r):
        """Return a generator of dicts, one per row from ``r.fetchall()``."""
        columns = r.keys()
        fetched = r.fetchall()
        return (dict(zip(columns, row)) for row in fetched)
Ancestors
Instance variables
var display_type
-
Source code
@property
def display_type(self):
    """Label for the ``Many`` result kind."""
    return 'rows'
Methods
def transform(self, r)
-
Source code
def transform(self, r):
    """Return a generator of dicts, one per row from ``r.fetchall()``."""
    columns = r.keys()
    fetched = r.fetchall()
    return (dict(zip(columns, row)) for row in fetched)
class One
-
Source code
class One(Result):
    """Result kind that yields the first row as a dict, or ``None``."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'row'

    def transform(self, r):
        """Return the first row keyed by ``r.keys()``, or ``None`` if falsy."""
        first_row = r.first()
        if not first_row:
            return None
        return dict(zip(r.keys(), first_row))
Ancestors
Instance variables
var display_type
-
Source code
@property
def display_type(self):
    """Label for the ``One`` result kind."""
    return 'row'
Methods
def transform(self, r)
-
Source code
def transform(self, r):
    """Return the first row keyed by ``r.keys()``, or ``None`` if falsy."""
    first_row = r.first()
    if not first_row:
        return None
    return dict(zip(r.keys(), first_row))
class Raw
-
Source code
class Raw(Result):
    """Result kind that hands the result object back untouched."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'raw'

    def transform(self, r):
        """Return ``r`` as-is."""
        return r
Ancestors
Instance variables
var display_type
-
Source code
@property
def display_type(self):
    """Label for the ``Raw`` result kind."""
    return 'raw'
Methods
def transform(self, r)
-
Source code
def transform(self, r):
    """Return ``r`` as-is."""
    untouched = r
    return untouched
class Result
-
Source code
class Result(object):
    """Base class for objects that turn an executed result into a value.

    Both members raise ``NotImplementedError`` here; subclasses supply the
    actual behavior.
    """

    @property
    def display_type(self):
        """Short label for this result kind (not implemented on the base)."""
        raise NotImplementedError()

    def transform(self, r):
        """Convert the result object ``r`` (not implemented on the base)."""
        raise NotImplementedError()
Subclasses
Instance variables
var display_type
-
Source code
@property
def display_type(self):
    """Short label for the result kind (not implemented on the base)."""
    raise NotImplementedError()
Methods
def transform(self, r)
-
Source code
def transform(self, r):
    """Convert the result object ``r`` (not implemented on the base)."""
    raise NotImplementedError()
class Scalar
-
Source code
class Scalar(Result):
    """Result kind that yields the first column of the first row."""

    @property
    def display_type(self):
        """Label for this result kind."""
        return 'scalar'

    def transform(self, r):
        """Return element 0 of ``r.first()``, or ``None`` if that is falsy."""
        first_row = r.first()
        return first_row[0] if first_row else None
Ancestors
Subclasses
Instance variables
var display_type
-
Source code
@property
def display_type(self):
    """Label for the ``Scalar`` result kind."""
    return 'scalar'
Methods
def transform(self, r)
-
Source code
def transform(self, r):
    """Return element 0 of ``r.first()``, or ``None`` if that is falsy."""
    first_row = r.first()
    return first_row[0] if first_row else None
class Statement (name, sql, doc, result, filename=None)
-
Source code
class Statement(object):
    """A named SQL statement that is invoked like a function.

    The SQL string is wrapped with ``sqlalchemy.sql.text`` at construction.
    Calling the instance executes it through the associated module's
    ``_execute`` and passes the outcome to ``result.transform``.
    """

    def __init__(self, name, sql, doc, result, filename=None):
        """Validate the arguments and prepare the SQL text.

        Args:
            name: statement name; must be truthy.
            sql: SQL string; must not be ``None`` or blank. Stored stripped.
            doc: documentation for the statement (stored as given).
            result: object providing ``transform`` and ``display_type``;
                must be truthy.
            filename: optional origin file, appended to error messages.

        Raises:
            ValueError: if ``name``, ``sql`` or ``result`` is missing/empty.
        """
        # Assigned before validation because _value_err reads it.
        self.filename = filename

        if not name:
            self._value_err('Statement must have a name.')
        if sql is None:
            self._value_err('Statement must have a SQL string.')
        sql = sql.strip()
        if not sql:
            self._value_err('SQL string cannot be empty.')
        if not result:
            self._value_err('Statement must have a result type.')

        self.name = name
        self.sql = sql
        self.doc = doc
        self.result = result
        self._module = None
        self._text = sqlalchemy.sql.text(self.sql)

    def _value_err(self, msg):
        """Raise ``ValueError(msg)``, adding the filename when one is set."""
        if self.filename:
            raise ValueError('%s In: %s' % (msg, self.filename))
        raise ValueError(msg)

    def set_module(self, module):
        """Associate this statement with the module used to execute it."""
        self._module = module

    def _assert_module(self):
        """Raise ``RuntimeError`` if no module has been associated yet."""
        if self._module is None:
            raise RuntimeError(
                'This statement is not associated with a module')

    def __call__(self, *multiparams, **params):
        """Execute the statement and return the transformed result.

        Args:
            *multiparams: positional parameter sets (dicts or lists).
            **params: keyword parameters bound by name.

        Returns:
            The value produced by ``self.result.transform``.

        Raises:
            RuntimeError: if no module is associated.
            InvalidArgumentError: if the arguments look like plain
                positional values rather than keyword parameters.
        """
        self._assert_module()
        multiparams, params = self._convert_params(multiparams, params)
        self._validateMultiparams(params, multiparams)
        with _compile_context(multiparams, params):
            try:
                r = self._module._execute(self._text, *multiparams, **params)
            except AttributeError as e:
                # This exact message is treated as a sign of positional
                # misuse (presumably a tuple reached code that expected a
                # mapping) and is translated into the friendlier error.
                if str(e) == "'tuple' object has no attribute 'keys'":
                    self._positionalArgError()
                raise
        return self.result.transform(r)

    def _validateMultiparams(self, params, multiparams):
        """Reject common misuse of positional arguments.

        Positional arguments may not be combined with keyword arguments, and
        each one must be a ``dict``, ``list`` or ``set`` (subclasses
        included).
        """
        # try to catch some common usage mistakes
        if not multiparams:
            return
        if params:
            # currently never supported to pass both positional args and
            # keywords
            self._positionalArgError()
        for p in multiparams:
            # isinstance (not an exact type match) so that subclasses such as
            # OrderedDict are accepted.
            # NOTE: __call__ runs _convert_params first, which turns sets
            # into tuples, so a set positional argument is rejected here.
            if not isinstance(p, (dict, list, set)):
                self._positionalArgError()

    def _positionalArgError(self):
        """Raise the standard error for positional-argument misuse."""
        raise InvalidArgumentError(
            'Pass keyword arguments to statements (received '
            'positional arguments).')

    def _convert_params(self, multiparams, params):
        """Return ``(list, dict)`` copies of the arguments with sets as tuples."""
        def conv(x):
            if isinstance(x, set):
                return tuple(x)
            return x

        return (
            [conv(p) for p in multiparams],
            {k: conv(v) for k, v in params.items()})

    def _param_names(self):
        """Return bind parameter names ordered by first appearance in the SQL."""
        import re

        def kfn(p):
            # Match the whole ``:name`` token so that ``:a`` is not found
            # inside ``:a_b``; names that cannot be located sort last.
            found = re.search(
                r'(?<![:\w\\]):' + re.escape(p) + r'(?!\w)', self.sql)
            return found.start() if found else len(self.sql)

        return sorted(self._text._bindparams.keys(), key=kfn)

    def __str__(self):
        """Render as ``pugsql.statement.Statement: name(p=None, ...) :: type``."""
        paramstr = ', '.join(['%s=None' % k for k in self._param_names()])
        return 'pugsql.statement.Statement: %s(%s) :: %s' % (
            self.name, paramstr, self.result.display_type)

    def __repr__(self):
        """Same text as ``__str__``."""
        return str(self)
Methods
def set_module(self, module)
-
Source code
def set_module(self, module):
    """Associate this statement with the module used to execute it."""
    setattr(self, '_module', module)