Module pugsql.statement

Compiled SQL function objects.

Source code
"""
Compiled SQL function objects.
"""

import threading
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional

import sqlalchemy
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import BindParameter

from .exceptions import InvalidArgumentError

_locals = threading.local()

if TYPE_CHECKING:
    from .compiler import Module


class ArrayLiteral(object):
    def __init__(self, array):
        self.array = list(array)


@contextmanager
def _compile_context(multiparams, params):
    _locals.compile_context = {
        "multiparams": multiparams,
        "params": params,
    }
    try:
        yield
    finally:
        _locals.compile_context = None


@compiles(BindParameter)
def _visit_bindparam(element, compiler, **kw):
    cc = getattr(_locals, "compile_context", None)
    if cc:
        if _is_expanding_param(element, cc):
            element.expanding = True
    return compiler.visit_bindparam(element)


def _is_expanding_param(element, cc):
    if element.key not in cc["params"]:
        return False
    return isinstance(cc["params"][element.key], tuple)


class Result(object):
    def transform(self, r):
        raise NotImplementedError()

    @property
    def display_type(self) -> str:
        raise NotImplementedError()


class One(Result):
    def transform(self, r):
        row = r.first()
        if row:
            return {k: v for k, v in zip(r.keys(), row)}
        return None

    @property
    def display_type(self) -> str:
        return "row"


class Many(Result):
    def transform(self, r):
        ks = r.keys()
        return ({k: v for k, v in zip(ks, row)} for row in r.fetchall())

    @property
    def display_type(self) -> str:
        return "rows"


class Affected(Result):
    def transform(self, r):
        return r.rowcount

    @property
    def display_type(self) -> str:
        return "rowcount"


class Scalar(Result):
    def transform(self, r):
        row = r.first()
        if not row:
            return None
        return row[0]

    @property
    def display_type(self) -> str:
        return "scalar"


class Insert(Scalar):
    def transform(self, r):
        if hasattr(r, "lastrowid"):
            return r.lastrowid
        return super(Insert, self).transform(r)

    @property
    def display_type(self) -> str:
        return "insert"


class Raw(Result):
    def transform(self, r):
        return r

    @property
    def display_type(self) -> str:
        return "raw"


class Statement(object):
    def __init__(
        self,
        name: str,
        sql: str,
        doc: str,
        result: Result,
        filename: Optional[str] = None,
    ):
        self.filename = filename

        if not name:
            self._value_err("Statement must have a name.")

        if sql is None:
            self._value_err("Statement must have a SQL string.")
        sql = sql.strip()
        if not len(sql):
            self._value_err("SQL string cannot be empty.")

        if not result:
            self._value_err("Statement must have a result type.")

        self.name = name
        self.sql = sql
        self.doc = doc
        self.result = result
        self.filename = filename
        self._module = None
        self._text = sqlalchemy.sql.text(self.sql)

    def _value_err(self, msg):
        if self.filename:
            raise ValueError("%s In: %s" % (msg, self.filename))
        raise ValueError(msg)

    def set_module(self, module: "Module"):
        self._module = module

    def _assert_module(self) -> "Module":
        if self._module is None:
            raise RuntimeError(
                "This statement is not associated with a module"
            )
        return self._module

    def __call__(self, *multiparams, **params):
        module = self._assert_module()
        multiparams, params = self._convert_params(multiparams, params)
        self._validateMultiparams(params, multiparams)
        with _compile_context(multiparams, params):
            try:
                r = module._execute(self._text, *multiparams, **params)
            except AttributeError as e:
                if str(e) == "'tuple' object has no attribute 'keys'":
                    self._positionalArgError()
                raise
        return self.result.transform(r)

    def _validateMultiparams(self, params, multiparams):
        # try to catch some common usage mistakes
        if not len(multiparams):
            return

        if len(params):
            # currently never supported to pass both positional args and
            # keywords
            self._positionalArgError()

        for p in multiparams:
            # multiparams are allowed when they're tuples/rows/etc to be e.g.
            # inserted
            if not type(p) in {dict, list, set, tuple}:
                self._positionalArgError()

    def _positionalArgError(self):
        raise InvalidArgumentError(
            "Pass keyword arguments to statements (received "
            "positional arguments)."
        )

    def _convert_params(self, multiparams, params):
        def conv(x):
            if isinstance(x, set) or isinstance(x, list):
                return tuple(x)
            elif isinstance(x, ArrayLiteral):
                return x.array
            return x

        return (
            [conv(p) for p in multiparams],
            {k: conv(v) for k, v in params.items()},
        )

    def _param_names(self):
        def kfn(p):
            return self.sql.index(":" + p)

        return sorted(self._text._bindparams.keys(), key=kfn)

    def __str__(self):
        paramstr = ", ".join(["%s=None" % k for k in self._param_names()])
        return "pugsql.statement.Statement: %s(%s) :: %s" % (
            self.name,
            paramstr,
            self.result.display_type,
        )

    def __repr__(self):
        return str(self)

Classes

class Affected
Source code
class Affected(Result):
    def transform(self, r):
        return r.rowcount

    @property
    def display_type(self) -> str:
        return "rowcount"

Ancestors

Instance variables

var display_type
Source code
@property
def display_type(self) -> str:
    return "rowcount"

Methods

def transform(self, r)
Source code
def transform(self, r):
    return r.rowcount
class ArrayLiteral (array)
Source code
class ArrayLiteral(object):
    def __init__(self, array):
        self.array = list(array)
class Insert
Source code
class Insert(Scalar):
    def transform(self, r):
        if hasattr(r, "lastrowid"):
            return r.lastrowid
        return super(Insert, self).transform(r)

    @property
    def display_type(self) -> str:
        return "insert"

Ancestors

Instance variables

var display_type
Source code
@property
def display_type(self) -> str:
    return "insert"

Methods

def transform(self, r)
Source code
def transform(self, r):
    if hasattr(r, "lastrowid"):
        return r.lastrowid
    return super(Insert, self).transform(r)
class Many
Source code
class Many(Result):
    def transform(self, r):
        ks = r.keys()
        return ({k: v for k, v in zip(ks, row)} for row in r.fetchall())

    @property
    def display_type(self) -> str:
        return "rows"

Ancestors

Instance variables

var display_type
Source code
@property
def display_type(self) -> str:
    return "rows"

Methods

def transform(self, r)
Source code
def transform(self, r):
    ks = r.keys()
    return ({k: v for k, v in zip(ks, row)} for row in r.fetchall())
class One
Source code
class One(Result):
    def transform(self, r):
        row = r.first()
        if row:
            return {k: v for k, v in zip(r.keys(), row)}
        return None

    @property
    def display_type(self) -> str:
        return "row"

Ancestors

Instance variables

var display_type
Source code
@property
def display_type(self) -> str:
    return "row"

Methods

def transform(self, r)
Source code
def transform(self, r):
    row = r.first()
    if row:
        return {k: v for k, v in zip(r.keys(), row)}
    return None
class Raw
Source code
class Raw(Result):
    def transform(self, r):
        return r

    @property
    def display_type(self) -> str:
        return "raw"

Ancestors

Instance variables

var display_type
Source code
@property
def display_type(self) -> str:
    return "raw"

Methods

def transform(self, r)
Source code
def transform(self, r):
    return r
class Result
Source code
class Result(object):
    def transform(self, r):
        raise NotImplementedError()

    @property
    def display_type(self) -> str:
        raise NotImplementedError()

Subclasses

Instance variables

var display_type
Source code
@property
def display_type(self) -> str:
    raise NotImplementedError()

Methods

def transform(self, r)
Source code
def transform(self, r):
    raise NotImplementedError()
class Scalar
Source code
class Scalar(Result):
    def transform(self, r):
        row = r.first()
        if not row:
            return None
        return row[0]

    @property
    def display_type(self) -> str:
        return "scalar"

Ancestors

Subclasses

Instance variables

var display_type
Source code
@property
def display_type(self) -> str:
    return "scalar"

Methods

def transform(self, r)
Source code
def transform(self, r):
    row = r.first()
    if not row:
        return None
    return row[0]
class Statement (name: str, sql: str, doc: str, result: Result, filename: str | None = None)
Source code
class Statement(object):
    def __init__(
        self,
        name: str,
        sql: str,
        doc: str,
        result: Result,
        filename: Optional[str] = None,
    ):
        self.filename = filename

        if not name:
            self._value_err("Statement must have a name.")

        if sql is None:
            self._value_err("Statement must have a SQL string.")
        sql = sql.strip()
        if not len(sql):
            self._value_err("SQL string cannot be empty.")

        if not result:
            self._value_err("Statement must have a result type.")

        self.name = name
        self.sql = sql
        self.doc = doc
        self.result = result
        self.filename = filename
        self._module = None
        self._text = sqlalchemy.sql.text(self.sql)

    def _value_err(self, msg):
        if self.filename:
            raise ValueError("%s In: %s" % (msg, self.filename))
        raise ValueError(msg)

    def set_module(self, module: "Module"):
        self._module = module

    def _assert_module(self) -> "Module":
        if self._module is None:
            raise RuntimeError(
                "This statement is not associated with a module"
            )
        return self._module

    def __call__(self, *multiparams, **params):
        module = self._assert_module()
        multiparams, params = self._convert_params(multiparams, params)
        self._validateMultiparams(params, multiparams)
        with _compile_context(multiparams, params):
            try:
                r = module._execute(self._text, *multiparams, **params)
            except AttributeError as e:
                if str(e) == "'tuple' object has no attribute 'keys'":
                    self._positionalArgError()
                raise
        return self.result.transform(r)

    def _validateMultiparams(self, params, multiparams):
        # try to catch some common usage mistakes
        if not len(multiparams):
            return

        if len(params):
            # currently never supported to pass both positional args and
            # keywords
            self._positionalArgError()

        for p in multiparams:
            # multiparams are allowed when they're tuples/rows/etc to be e.g.
            # inserted
            if not type(p) in {dict, list, set, tuple}:
                self._positionalArgError()

    def _positionalArgError(self):
        raise InvalidArgumentError(
            "Pass keyword arguments to statements (received "
            "positional arguments)."
        )

    def _convert_params(self, multiparams, params):
        def conv(x):
            if isinstance(x, set) or isinstance(x, list):
                return tuple(x)
            elif isinstance(x, ArrayLiteral):
                return x.array
            return x

        return (
            [conv(p) for p in multiparams],
            {k: conv(v) for k, v in params.items()},
        )

    def _param_names(self):
        def kfn(p):
            return self.sql.index(":" + p)

        return sorted(self._text._bindparams.keys(), key=kfn)

    def __str__(self):
        paramstr = ", ".join(["%s=None" % k for k in self._param_names()])
        return "pugsql.statement.Statement: %s(%s) :: %s" % (
            self.name,
            paramstr,
            self.result.display_type,
        )

    def __repr__(self):
        return str(self)

Methods

def set_module(self, module: Module)
Source code
def set_module(self, module: "Module"):
    self._module = module