Updated SqlAlchemy

This commit is contained in:
Christoffer Viken 2017-04-15 16:27:12 +00:00
parent 2c790e1fe1
commit e3267d4bda
59 changed files with 30236 additions and 26049 deletions

View File

@ -1,24 +1,23 @@
# __init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
# sqlalchemy/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import inspect
import sys
import sqlalchemy.exc as exceptions
sys.modules['sqlalchemy.exceptions'] = exceptions
from sqlalchemy.sql import (
from .sql import (
alias,
all_,
and_,
any_,
asc,
between,
bindparam,
case,
cast,
collate,
column,
delete,
desc,
distinct,
@ -26,11 +25,14 @@ from sqlalchemy.sql import (
except_all,
exists,
extract,
false,
func,
funcfilter,
insert,
intersect,
intersect_all,
join,
lateral,
literal,
literal_column,
modifier,
@ -39,16 +41,25 @@ from sqlalchemy.sql import (
or_,
outerjoin,
outparam,
over,
select,
subquery,
table,
tablesample,
text,
true,
tuple_,
type_coerce,
union,
union_all,
update,
within_group,
)
from sqlalchemy.types import (
from .types import (
ARRAY,
BIGINT,
BINARY,
BLOB,
BOOLEAN,
BigInteger,
@ -68,12 +79,14 @@ from sqlalchemy.types import (
INTEGER,
Integer,
Interval,
JSON,
LargeBinary,
NCHAR,
NVARCHAR,
NUMERIC,
Numeric,
PickleType,
REAL,
SMALLINT,
SmallInteger,
String,
@ -82,18 +95,19 @@ from sqlalchemy.types import (
TIMESTAMP,
Text,
Time,
TypeDecorator,
Unicode,
UnicodeText,
VARBINARY,
VARCHAR,
)
from sqlalchemy.schema import (
from .schema import (
CheckConstraint,
Column,
ColumnDefault,
Constraint,
DDL,
DefaultClause,
FetchedValue,
ForeignKey,
@ -106,14 +120,27 @@ from sqlalchemy.schema import (
Table,
ThreadLocalMetaData,
UniqueConstraint,
)
from sqlalchemy.engine import create_engine, engine_from_config
DDL,
BLANK_SCHEMA
)
__all__ = sorted(name for name, obj in locals().items()
if not (name.startswith('_') or inspect.ismodule(obj)))
__version__ = '0.6beta3'
from .inspection import inspect
from .engine import create_engine, engine_from_config
del inspect, sys
__version__ = '1.1.9'
def __go(lcls):
global __all__
from . import events
from . import util as _sa_util
import inspect as _inspect
__all__ = sorted(name for name, obj in lcls.items()
if not (name.startswith('_') or _inspect.ismodule(obj)))
_sa_util.dependencies.resolve_all("sqlalchemy")
__go(locals())

View File

@ -1,6 +1,10 @@
# connectors/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
class Connector(object):
pass

View File

@ -1,5 +1,12 @@
# connectors/mxodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Provide an SQLALchemy connector for the eGenix mxODBC commercial
Provide a SQLALchemy connector for the eGenix mxODBC commercial
Python adapter for ODBC. This is not a free product, but eGenix
provides SQLAlchemy with a license for use in continuous integration
testing.
@ -15,21 +22,19 @@ For more info on mxODBC, see http://www.egenix.com/
import sys
import re
import warnings
from decimal import Decimal
from sqlalchemy.connectors import Connector
from sqlalchemy import types as sqltypes
import sqlalchemy.processors as processors
from . import Connector
class MxODBCConnector(Connector):
driver='mxodbc'
driver = 'mxodbc'
supports_sane_multi_rowcount = False
supports_unicode_statements = False
supports_unicode_binds = False
supports_unicode_statements = True
supports_unicode_binds = True
supports_native_decimal = True
@classmethod
def dbapi(cls):
# this classmethod will normally be replaced by an instance
@ -44,7 +49,7 @@ class MxODBCConnector(Connector):
elif platform == 'darwin':
from mx.ODBC import iODBC as module
else:
raise ImportError, "Unrecognized platform for mxODBC import"
raise ImportError("Unrecognized platform for mxODBC import")
return module
@classmethod
@ -64,21 +69,21 @@ class MxODBCConnector(Connector):
conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT
conn.errorhandler = self._error_handler()
return connect
def _error_handler(self):
""" Return a handler that adjusts mxODBC's raised Warnings to
emit Python standard warnings.
"""
from mx.ODBC.Error import Warning as MxOdbcWarning
def error_handler(connection, cursor, errorclass, errorvalue):
def error_handler(connection, cursor, errorclass, errorvalue):
if issubclass(errorclass, MxOdbcWarning):
errorclass.__bases__ = (Warning,)
warnings.warn(message=str(errorvalue),
category=errorclass,
stacklevel=2)
category=errorclass,
stacklevel=2)
else:
raise errorclass, errorvalue
raise errorclass(errorvalue)
return error_handler
def create_connect_args(self, url):
@ -94,7 +99,7 @@ class MxODBCConnector(Connector):
The arg 'errorhandler' is not used by SQLAlchemy and will
not be populated.
"""
opts = url.translate_connect_args(username='user')
opts.update(url.query)
@ -103,9 +108,9 @@ class MxODBCConnector(Connector):
opts.pop('database', None)
return (args,), opts
def is_disconnect(self, e):
# eGenix recommends checking connection.closed here,
# but how can we get a handle on the current connection?
def is_disconnect(self, e, connection, cursor):
# TODO: eGenix recommends checking connection.closed here
# Does that detect dropped connections ?
if isinstance(e, self.dbapi.ProgrammingError):
return "connection already closed" in str(e)
elif isinstance(e, self.dbapi.Error):
@ -114,10 +119,11 @@ class MxODBCConnector(Connector):
return False
def _get_server_version_info(self, connection):
# eGenix suggests using conn.dbms_version instead of what we're doing here
# eGenix suggests using conn.dbms_version instead
# of what we're doing here
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
r = re.compile(r'[.\-]')
# 18 == pyodbc.SQL_DBMS_VER
for n in r.split(dbapi_con.getinfo(18)[1]):
try:
@ -126,21 +132,19 @@ class MxODBCConnector(Connector):
version.append(n)
return tuple(version)
def do_execute(self, cursor, statement, parameters, context=None):
def _get_direct(self, context):
if context:
native_odbc_execute = context.execution_options.\
get('native_odbc_execute', 'auto')
if native_odbc_execute is True:
# user specified native_odbc_execute=True
cursor.execute(statement, parameters)
elif native_odbc_execute is False:
# user specified native_odbc_execute=False
cursor.executedirect(statement, parameters)
elif context.is_crud:
# statement is UPDATE, DELETE, INSERT
cursor.execute(statement, parameters)
else:
# all other statements
cursor.executedirect(statement, parameters)
get('native_odbc_execute', 'auto')
# default to direct=True in all cases, is more generally
# compatible especially with SQL Server
return False if native_odbc_execute is True else True
else:
cursor.executedirect(statement, parameters)
return True
def do_executemany(self, cursor, statement, parameters, context=None):
cursor.executemany(
statement, parameters, direct=self._get_direct(context))
def do_execute(self, cursor, statement, parameters, context=None):
cursor.execute(statement, parameters, direct=self._get_direct(context))

View File

@ -1,29 +1,51 @@
from sqlalchemy.connectors import Connector
from sqlalchemy.util import asbool
# connectors/pyodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import Connector
from .. import util
import sys
import re
import urllib
import decimal
class PyODBCConnector(Connector):
driver='pyodbc'
driver = 'pyodbc'
supports_sane_multi_rowcount = False
# PyODBC unicode is broken on UCS-4 builds
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = supports_unicode
if util.py2k:
# PyODBC unicode is broken on UCS-4 builds
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = supports_unicode
supports_native_decimal = True
default_paramstyle = 'named'
# for non-DSN connections, this should
# for non-DSN connections, this *may* be used to
# hold the desired driver name
pyodbc_driver_name = None
# will be set to True after initialize()
# if the freetds.so is detected
freetds = False
# will be set to the string version of
# the FreeTDS driver if freetds is detected
freetds_driver_version = None
# will be set to True after initialize()
# if the libessqlsrv.so is detected
easysoft = False
def __init__(self, supports_unicode_binds=None, **kw):
super(PyODBCConnector, self).__init__(**kw)
self._user_supports_unicode_binds = supports_unicode_binds
@classmethod
def dbapi(cls):
return __import__('pyodbc')
@ -31,29 +53,53 @@ class PyODBCConnector(Connector):
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
keys = opts
query = url.query
connect_args = {}
for param in ('ansi', 'unicode_results', 'autocommit'):
if param in keys:
connect_args[param] = asbool(keys.pop(param))
connect_args[param] = util.asbool(keys.pop(param))
if 'odbc_connect' in keys:
connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
connectors = [util.unquote_plus(keys.pop('odbc_connect'))]
else:
dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
def check_quote(token):
if ";" in str(token):
token = "'%s'" % token
return token
keys = dict(
(k, check_quote(v)) for k, v in keys.items()
)
dsn_connection = 'dsn' in keys or \
('host' in keys and 'database' not in keys)
if dsn_connection:
connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
connectors = ['dsn=%s' % (keys.pop('host', '') or
keys.pop('dsn', ''))]
else:
port = ''
if 'port' in keys and not 'port' in query:
if 'port' in keys and 'port' not in query:
port = ',%d' % int(keys.pop('port'))
connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name),
'Server=%s%s' % (keys.pop('host', ''), port),
'Database=%s' % keys.pop('database', '') ]
connectors = []
driver = keys.pop('driver', self.pyodbc_driver_name)
if driver is None:
util.warn(
"No driver name specified; "
"this is expected by PyODBC when using "
"DSN-less connections")
else:
connectors.append("DRIVER={%s}" % driver)
connectors.extend(
[
'Server=%s%s' % (keys.pop('host', ''), port),
'Database=%s' % keys.pop('database', '')
])
user = keys.pop("user", None)
if user:
@ -62,20 +108,22 @@ class PyODBCConnector(Connector):
else:
connectors.append("Trusted_Connection=Yes")
# if set to 'Yes', the ODBC layer will try to automagically convert
# textual data from your database encoding to your client encoding
# This should obviously be set to 'No' if you query a cp1253 encoded
# database from a latin1 client...
# if set to 'Yes', the ODBC layer will try to automagically
# convert textual data from your database encoding to your
# client encoding. This should obviously be set to 'No' if
# you query a cp1253 encoded database from a latin1 client...
if 'odbc_autotranslate' in keys:
connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
connectors.append("AutoTranslate=%s" %
keys.pop("odbc_autotranslate"))
connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
return [[";".join (connectors)], connect_args]
def is_disconnect(self, e):
connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()])
return [[";".join(connectors)], connect_args]
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.ProgrammingError):
return "The cursor's connection has been closed." in str(e) or \
'Attempt to use a closed connection.' in str(e)
'Attempt to use a closed connection.' in str(e)
elif isinstance(e, self.dbapi.Error):
return '[08S01]' in str(e)
else:
@ -84,27 +132,62 @@ class PyODBCConnector(Connector):
def initialize(self, connection):
# determine FreeTDS first. can't issue SQL easily
# without getting unicode_statements/binds set up.
pyodbc = self.dbapi
dbapi_con = connection.connection
self.freetds = bool(re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)))
_sql_driver_name = dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)
self.freetds = bool(re.match(r".*libtdsodbc.*\.so", _sql_driver_name
))
self.easysoft = bool(re.match(r".*libessqlsrv.*\.so", _sql_driver_name
))
if self.freetds:
self.freetds_driver_version = dbapi_con.getinfo(
pyodbc.SQL_DRIVER_VER)
self.supports_unicode_statements = (
not util.py2k or
(not self.freetds and not self.easysoft)
)
if self._user_supports_unicode_binds is not None:
self.supports_unicode_binds = self._user_supports_unicode_binds
elif util.py2k:
self.supports_unicode_binds = (
not self.freetds or self.freetds_driver_version >= '0.91'
) and not self.easysoft
else:
self.supports_unicode_binds = True
# the "Py2K only" part here is theoretical.
# have not tried pyodbc + python3.1 yet.
# Py2K
self.supports_unicode_statements = not self.freetds
self.supports_unicode_binds = not self.freetds
# end Py2K
# run other initialization which asks for user name, etc.
super(PyODBCConnector, self).initialize(connection)
def _dbapi_version(self):
if not self.dbapi:
return ()
return self._parse_dbapi_version(self.dbapi.version)
def _parse_dbapi_version(self, vers):
m = re.match(
r'(?:py.*-)?([\d\.]+)(?:-(\w+))?',
vers
)
if not m:
return ()
vers = tuple([int(x) for x in m.group(1).split(".")])
if m.group(2):
vers += (m.group(2),)
return vers
def _get_server_version_info(self, connection):
# NOTE: this function is not reliable, particularly when
# freetds is in use. Implement database-specific server version
# queries.
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
r = re.compile(r'[.\-]')
for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
try:
version.append(int(n))

View File

@ -1,20 +1,28 @@
# connectors/zxJDBC.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sys
from sqlalchemy.connectors import Connector
from . import Connector
class ZxJDBCConnector(Connector):
driver = 'zxjdbc'
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
supports_unicode_binds = True
supports_unicode_statements = sys.version > '2.5.0+'
description_encoding = None
default_paramstyle = 'qmark'
jdbc_db_name = None
jdbc_driver_name = None
@classmethod
def dbapi(cls):
from com.ziclix.python.sql import zxJDBC
@ -23,20 +31,24 @@ class ZxJDBCConnector(Connector):
def _driver_kwargs(self):
"""Return kw arg dict to be sent to connect()."""
return {}
def _create_jdbc_url(self, url):
"""Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`"""
return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host,
url.port is not None and ':%s' % url.port or '',
url.port is not None
and ':%s' % url.port or '',
url.database)
def create_connect_args(self, url):
opts = self._driver_kwargs()
opts.update(url.query)
return [[self._create_jdbc_url(url), url.username, url.password, self.jdbc_driver_name],
opts]
return [
[self._create_jdbc_url(url),
url.username, url.password,
self.jdbc_driver_name],
opts]
def is_disconnect(self, e):
def is_disconnect(self, e, connection, cursor):
if not isinstance(e, self.dbapi.ProgrammingError):
return False
e = str(e)

View File

@ -1,12 +1,56 @@
# dialects/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
__all__ = (
# 'access',
# 'firebird',
# 'informix',
# 'maxdb',
# 'mssql',
'firebird',
'mssql',
'mysql',
'oracle',
'postgresql',
'sqlite',
# 'sybase',
)
'sybase',
)
from .. import util
_translates = {'postgres': 'postgresql'}
def _auto_fn(name):
"""default dialect importer.
plugs into the :class:`.PluginLoader`
as a first-hit system.
"""
if "." in name:
dialect, driver = name.split(".")
else:
dialect = name
driver = "base"
if dialect in _translates:
translated = _translates[dialect]
util.warn_deprecated(
"The '%s' dialect name has been "
"renamed to '%s'" % (dialect, translated)
)
dialect = translated
try:
module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
except ImportError:
return None
module = getattr(module, dialect)
if hasattr(module, driver):
module = getattr(module, driver)
return lambda: module.dialect
else:
return None
registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
plugins = util.PluginLoader("sqlalchemy.plugins")

View File

@ -1,14 +1,36 @@
from sqlalchemy.dialects.postgresql import base, psycopg2, pg8000, pypostgresql, zxjdbc
# postgresql/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import base, psycopg2, pg8000, pypostgresql, pygresql, \
zxjdbc, psycopg2cffi
base.dialect = psycopg2.dialect
from sqlalchemy.dialects.postgresql.base import \
INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, INET, \
CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME,\
DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect
from .base import \
INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \
INET, CIDR, UUID, BIT, MACADDR, OID, DOUBLE_PRECISION, TIMESTAMP, TIME, \
DATE, BYTEA, BOOLEAN, INTERVAL, ENUM, dialect, TSVECTOR, DropEnumType, \
CreateEnumType
from .hstore import HSTORE, hstore
from .json import JSON, JSONB
from .array import array, ARRAY, Any, All
from .ext import aggregate_order_by, ExcludeConstraint, array_agg
from .dml import insert, Insert
from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
TSTZRANGE
__all__ = (
'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', 'FLOAT', 'REAL', 'INET',
'CIDR', 'UUID', 'BIT', 'MACADDR', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME',
'DATE', 'BYTEA', 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect'
'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC',
'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR', 'OID',
'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN',
'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array', 'HSTORE',
'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE',
'TSRANGE', 'TSTZRANGE', 'json', 'JSON', 'JSONB', 'Any', 'All',
'DropEnumType', 'CreateEnumType', 'ExcludeConstraint',
'aggregate_order_by', 'array_agg', 'insert', 'Insert'
)

File diff suppressed because it is too large Load Diff

View File

@ -1,63 +1,130 @@
"""Support for the PostgreSQL database via the pg8000 driver.
# postgresql/pg8000.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
Connecting
----------
"""
.. dialect:: postgresql+pg8000
:name: pg8000
:dbapi: pg8000
:connectstring: \
postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
:url: https://pythonhosted.org/pg8000/
URLs are of the form
`postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]`.
.. _pg8000_unicode:
Unicode
-------
pg8000 requires that the postgresql client encoding be configured in the postgresql.conf file
in order to use encodings other than ascii. Set this value to the same value as
the "encoding" parameter on create_engine(), usually "utf-8".
pg8000 will encode / decode string values between it and the server using the
PostgreSQL ``client_encoding`` parameter; by default this is the value in
the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
Typically, this can be changed to ``utf-8``, as a more useful default::
Interval
--------
#client_encoding = sql_ascii # actually, defaults to database
# encoding
client_encoding = utf8
The ``client_encoding`` can be overridden for a session by executing the SQL:
SET CLIENT_ENCODING TO 'utf8';
SQLAlchemy will execute this SQL on all new connections based on the value
passed to :func:`.create_engine` using the ``client_encoding`` parameter::
engine = create_engine(
"postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
.. _pg8000_isolation_level:
pg8000 Transaction Isolation Level
-------------------------------------
The pg8000 dialect offers the same isolation level settings as that
of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
* ``READ COMMITTED``
* ``READ UNCOMMITTED``
* ``REPEATABLE READ``
* ``SERIALIZABLE``
* ``AUTOCOMMIT``
.. versionadded:: 0.9.5 support for AUTOCOMMIT isolation level when using
pg8000.
.. seealso::
:ref:`postgresql_isolation_level`
:ref:`psycopg2_isolation_level`
Passing data from/to the Interval type is not supported as of yet.
"""
from ... import util, exc
import decimal
from ... import processors
from ... import types as sqltypes
from .base import (
PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext,
_DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES)
import re
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.engine import default
from sqlalchemy import util, exc
from sqlalchemy import processors
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, \
PGCompiler, PGIdentifierPreparer, PGExecutionContext
class _PGNumeric(sqltypes.Numeric):
def result_processor(self, dialect, coltype):
if self.asdecimal:
if coltype in (700, 701):
return processors.to_decimal_processor_factory(decimal.Decimal)
elif coltype == 1700:
if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
decimal.Decimal, self._effective_decimal_return_scale)
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# pg8000 returns Decimal natively for 1700
return None
else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
else:
if coltype in (700, 701):
if coltype in _FLOAT_TYPES:
# pg8000 returns float natively for 701
return None
elif coltype == 1700:
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float
else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
class _PGNumericNoBind(_PGNumeric):
def bind_processor(self, dialect):
return None
class _PGJSON(JSON):
def result_processor(self, dialect, coltype):
if dialect._dbapi_version > (1, 10, 1):
return None # Has native JSON
else:
return super(_PGJSON, self).result_processor(dialect, coltype)
class PGExecutionContext_pg8000(PGExecutionContext):
pass
class PGCompiler_pg8000(PGCompiler):
def visit_mod(self, binary, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right)
def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
def post_process_text(self, text):
if '%%' in text:
util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() "
util.warn("The SQLAlchemy postgresql dialect "
"now automatically escapes '%' in text() "
"expressions to '%%'.")
return text.replace('%', '%%')
@ -67,30 +134,52 @@ class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace('%', '%%')
class PGDialect_pg8000(PGDialect):
driver = 'pg8000'
supports_unicode_statements = True
supports_unicode_binds = True
default_paramstyle = 'format'
supports_sane_multi_rowcount = False
supports_sane_multi_rowcount = True
execution_ctx_cls = PGExecutionContext_pg8000
statement_compiler = PGCompiler_pg8000
preparer = PGIdentifierPreparer_pg8000
description_encoding = 'use_encoding'
colspecs = util.update_copy(
PGDialect.colspecs,
{
sqltypes.Numeric : _PGNumeric,
sqltypes.Numeric: _PGNumericNoBind,
sqltypes.Float: _PGNumeric,
JSON: _PGJSON,
sqltypes.JSON: _PGJSON
}
)
def __init__(self, client_encoding=None, **kwargs):
PGDialect.__init__(self, **kwargs)
self.client_encoding = client_encoding
def initialize(self, connection):
self.supports_sane_multi_rowcount = self._dbapi_version >= (1, 9, 14)
super(PGDialect_pg8000, self).initialize(connection)
@util.memoized_property
def _dbapi_version(self):
if self.dbapi and hasattr(self.dbapi, '__version__'):
return tuple(
[
int(x) for x in re.findall(
r'(\d+)(?:[-\.]?|$)', self.dbapi.__version__)])
else:
return (99, 99, 99)
@classmethod
def dbapi(cls):
return __import__('pg8000').dbapi
return __import__('pg8000')
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
@ -99,7 +188,78 @@ class PGDialect_pg8000(PGDialect):
opts.update(url.query)
return ([], opts)
def is_disconnect(self, e):
def is_disconnect(self, e, connection, cursor):
return "connection is closed" in str(e)
def set_isolation_level(self, connection, level):
level = level.replace('_', ' ')
# adjust for ConnectionFairy possibly being present
if hasattr(connection, 'connection'):
connection = connection.connection
if level == 'AUTOCOMMIT':
connection.autocommit = True
elif level in self._isolation_lookup:
connection.autocommit = False
cursor = connection.cursor()
cursor.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION "
"ISOLATION LEVEL %s" % level)
cursor.execute("COMMIT")
cursor.close()
else:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
"Valid isolation levels for %s are %s or AUTOCOMMIT" %
(level, self.name, ", ".join(self._isolation_lookup))
)
def set_client_encoding(self, connection, client_encoding):
# adjust for ConnectionFairy possibly being present
if hasattr(connection, 'connection'):
connection = connection.connection
cursor = connection.cursor()
cursor.execute("SET CLIENT_ENCODING TO '" + client_encoding + "'")
cursor.execute("COMMIT")
cursor.close()
def do_begin_twophase(self, connection, xid):
connection.connection.tpc_begin((0, xid, ''))
def do_prepare_twophase(self, connection, xid):
connection.connection.tpc_prepare()
def do_rollback_twophase(
self, connection, xid, is_prepared=True, recover=False):
connection.connection.tpc_rollback((0, xid, ''))
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False):
connection.connection.tpc_commit((0, xid, ''))
def do_recover_twophase(self, connection):
return [row[1] for row in connection.connection.tpc_recover()]
def on_connect(self):
fns = []
if self.client_encoding is not None:
def on_connect(conn):
self.set_client_encoding(conn, self.client_encoding)
fns.append(on_connect)
if self.isolation_level is not None:
def on_connect(conn):
self.set_isolation_level(conn, self.isolation_level)
fns.append(on_connect)
if len(fns) > 0:
def on_connect(conn):
for fn in fns:
fn(conn)
return on_connect
else:
return None
dialect = PGDialect_pg8000

View File

@ -1,74 +1,335 @@
"""Support for the PostgreSQL database via the psycopg2 driver.
# postgresql/psycopg2.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
Driver
------
"""
.. dialect:: postgresql+psycopg2
:name: psycopg2
:dbapi: psycopg2
:connectstring: postgresql+psycopg2://user:password@host:port/dbname\
[?key=value&key=value...]
:url: http://pypi.python.org/pypi/psycopg2/
The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
The dialect has several behaviors which are specifically tailored towards compatibility
with this module.
psycopg2 Connect Arguments
-----------------------------------
Note that psycopg1 is **not** supported.
psycopg2-specific keyword arguments which are accepted by
:func:`.create_engine()` are:
Connecting
----------
* ``server_side_cursors``: Enable the usage of "server side cursors" for SQL
statements which support this feature. What this essentially means from a
psycopg2 point of view is that the cursor is created using a name, e.g.
``connection.cursor('some name')``, which has the effect that result rows
are not immediately pre-fetched and buffered after statement execution, but
are instead left on the server and only retrieved as needed. SQLAlchemy's
:class:`~sqlalchemy.engine.ResultProxy` uses special row-buffering
behavior when this feature is enabled, such that groups of 100 rows at a
time are fetched over the wire to reduce conversational overhead.
Note that the :paramref:`.Connection.execution_options.stream_results`
execution option is a more targeted
way of enabling this mode on a per-execution basis.
* ``use_native_unicode``: Enable the usage of Psycopg2 "native unicode" mode
per connection. True by default.
URLs are of the form `postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]`.
.. seealso::
psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
:ref:`psycopg2_disable_native_unicode`
* ``isolation_level``: This option, available for all PostgreSQL dialects,
includes the ``AUTOCOMMIT`` isolation level when using the psycopg2
dialect.
.. seealso::
:ref:`psycopg2_isolation_level`
* ``client_encoding``: sets the client encoding in a libpq-agnostic way,
using psycopg2's ``set_client_encoding()`` method.
.. seealso::
:ref:`psycopg2_unicode`
Unix Domain Connections
------------------------
psycopg2 supports connecting via Unix domain connections. When the ``host``
portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2,
which specifies Unix-domain communication rather than TCP/IP communication::
create_engine("postgresql+psycopg2://user:password@/dbname")
By default, the socket file used is to connect to a Unix-domain socket
in ``/tmp``, or whatever socket directory was specified when PostgreSQL
was built. This value can be overridden by passing a pathname to psycopg2,
using ``host`` as an additional keyword argument::
create_engine("postgresql+psycopg2://user:password@/dbname?\
host=/var/lib/postgresql")
See also:
`PQconnectdbParams <http://www.postgresql.org/docs/9.1/static/\
libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_
.. _psycopg2_execution_options:
Per-Statement/Connection Execution Options
-------------------------------------------
The following DBAPI-specific options are respected when used with
:meth:`.Connection.execution_options`, :meth:`.Executable.execution_options`,
:meth:`.Query.execution_options`, in addition to those not specific to DBAPIs:
* ``isolation_level`` - Set the transaction isolation level for the lifespan of a
:class:`.Connection` (can only be set on a connection, not a statement
or query). See :ref:`psycopg2_isolation_level`.
* ``stream_results`` - Enable or disable usage of psycopg2 server side cursors -
this feature makes use of "named" cursors in combination with special
result handling methods so that result rows are not fully buffered.
If ``None`` or not set, the ``server_side_cursors`` option of the
:class:`.Engine` is used.
* ``max_row_buffer`` - when using ``stream_results``, an integer value that
specifies the maximum number of rows to buffer at a time. This is
interpreted by the :class:`.BufferedRowResultProxy`, and if omitted the
buffer will grow to ultimately store 1000 rows at a time.
.. versionadded:: 1.0.6
.. _psycopg2_unicode:
Unicode with Psycopg2
----------------------
By default, the psycopg2 driver uses the ``psycopg2.extensions.UNICODE``
extension, such that the DBAPI receives and returns all strings as Python
Unicode objects directly - SQLAlchemy passes these values through without
change. Psycopg2 here will encode/decode string values based on the
current "client encoding" setting; by default this is the value in
the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
Typically, this can be changed to ``utf8``, as a more useful default::
# postgresql.conf file
# client_encoding = sql_ascii # actually, defaults to database
# encoding
client_encoding = utf8
A second way to affect the client encoding is to set it within Psycopg2
locally. SQLAlchemy will call psycopg2's
:meth:`psycopg2:connection.set_client_encoding` method
on all new connections based on the value passed to
:func:`.create_engine` using the ``client_encoding`` parameter::
# set_client_encoding() setting;
# works for *all* PostgreSQL versions
engine = create_engine("postgresql://user:pass@host/dbname",
client_encoding='utf8')
This overrides the encoding specified in the PostgreSQL client configuration.
When using the parameter in this way, the psycopg2 driver emits
``SET client_encoding TO 'utf8'`` on the connection explicitly, and works
in all PostgreSQL versions.
Note that the ``client_encoding`` setting as passed to :func:`.create_engine`
is **not the same** as the more recently added ``client_encoding`` parameter
now supported by libpq directly. This is enabled when ``client_encoding``
is passed directly to ``psycopg2.connect()``, and from SQLAlchemy is passed
using the :paramref:`.create_engine.connect_args` parameter::
# libpq direct parameter setting;
# only works for PostgreSQL **9.1 and above**
engine = create_engine("postgresql://user:pass@host/dbname",
connect_args={'client_encoding': 'utf8'})
# using the query string is equivalent
engine = create_engine("postgresql://user:pass@host/dbname?client_encoding=utf8")
The above parameter was only added to libpq as of version 9.1 of PostgreSQL,
so using the previous method is better for cross-version support.
.. _psycopg2_disable_native_unicode:
Disabling Native Unicode
^^^^^^^^^^^^^^^^^^^^^^^^
SQLAlchemy can also be instructed to skip the usage of the psycopg2
``UNICODE`` extension and to instead utilize its own unicode encode/decode
services, which are normally reserved only for those DBAPIs that don't
fully support unicode directly. Passing ``use_native_unicode=False`` to
:func:`.create_engine` will disable usage of ``psycopg2.extensions.UNICODE``.
SQLAlchemy will instead encode data itself into Python bytestrings on the way
in and coerce from bytes on the way back,
using the value of the :func:`.create_engine` ``encoding`` parameter, which
defaults to ``utf-8``.
SQLAlchemy's own unicode encode/decode functionality is steadily becoming
obsolete as most DBAPIs now support unicode fully.
Bound Parameter Styles
----------------------
The default parameter style for the psycopg2 dialect is "pyformat", where
SQL is rendered using ``%(paramname)s`` style. This format has the limitation
that it does not accommodate the unusual case of parameter names that
actually contain percent or parenthesis symbols; as SQLAlchemy in many cases
generates bound parameter names based on the name of a column, the presence
of these characters in a column name can lead to problems.
There are two solutions to the issue of a :class:`.schema.Column` that contains
one of these characters in its name. One is to specify the
:paramref:`.schema.Column.key` for columns that have such names::
measurement = Table('measurement', metadata,
Column('Size (meters)', Integer, key='size_meters')
)
Above, an INSERT statement such as ``measurement.insert()`` will use
``size_meters`` as the parameter name, and a SQL expression such as
``measurement.c.size_meters > 10`` will derive the bound parameter name
from the ``size_meters`` key as well.
.. versionchanged:: 1.0.0 - SQL expressions will use :attr:`.Column.key`
as the source of naming when anonymous bound parameters are created
in SQL expressions; previously, this behavior only applied to
:meth:`.Table.insert` and :meth:`.Table.update` parameter names.
The other solution is to use a positional format; psycopg2 allows use of the
"format" paramstyle, which can be passed to
:paramref:`.create_engine.paramstyle`::
engine = create_engine(
'postgresql://scott:tiger@localhost:5432/test', paramstyle='format')
With the above engine, instead of a statement like::
INSERT INTO measurement ("Size (meters)") VALUES (%(Size (meters))s)
{'Size (meters)': 1}
we instead see::
INSERT INTO measurement ("Size (meters)") VALUES (%s)
(1, )
Where above, the dictionary style is converted into a tuple with positional
style.
* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
this feature. What this essentially means from a psycopg2 point of view is that the cursor is
created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
are not immediately pre-fetched and buffered after statement execution, but are instead left
on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows
at a time are fetched over the wire to reduce conversational overhead.
* *use_native_unicode* - Enable the usage of Psycopg2 "native unicode" mode per connection. True
by default.
* *isolation_level* - Sets the transaction isolation level for each transaction
within the engine. Valid isolation levels are `READ_COMMITTED`,
`READ_UNCOMMITTED`, `REPEATABLE_READ`, and `SERIALIZABLE`.
Transactions
------------
The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
.. _psycopg2_isolation_level:
Psycopg2 Transaction Isolation Level
-------------------------------------
As discussed in :ref:`postgresql_isolation_level`,
all PostgreSQL dialects support setting of transaction isolation level
both via the ``isolation_level`` parameter passed to :func:`.create_engine`,
as well as the ``isolation_level`` argument used by
:meth:`.Connection.execution_options`. When using the psycopg2 dialect, these
options make use of psycopg2's ``set_isolation_level()`` connection method,
rather than emitting a PostgreSQL directive; this is because psycopg2's
API-level setting is always emitted at the start of each transaction in any
case.
The psycopg2 dialect supports these constants for isolation level:
* ``READ COMMITTED``
* ``READ UNCOMMITTED``
* ``REPEATABLE READ``
* ``SERIALIZABLE``
* ``AUTOCOMMIT``
.. versionadded:: 0.8.2 support for AUTOCOMMIT isolation level when using
psycopg2.
.. seealso::
:ref:`postgresql_isolation_level`
:ref:`pg8000_isolation_level`
NOTICE logging
---------------
The psycopg2 dialect will log Postgresql NOTICE messages via the
The psycopg2 dialect will log PostgreSQL NOTICE messages via the
``sqlalchemy.dialects.postgresql`` logger::
import logging
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
.. _psycopg2_hstore::
Per-Statement Execution Options
-------------------------------
HSTORE type
------------
The following per-statement execution options are respected:
The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of
the HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension
by default when psycopg2 version 2.4 or greater is used, and
it is detected that the target database has the HSTORE type set up for use.
In other words, when the dialect makes the first
connection, a sequence like the following is performed:
* *stream_results* - Enable or disable usage of server side cursors for the SELECT-statement.
If *None* or not set, the *server_side_cursors* option of the connection is used. If
auto-commit is enabled, the option is ignored.
1. Request the available HSTORE oids using
``psycopg2.extras.HstoreAdapter.get_oids()``.
If this function returns a list of HSTORE identifiers, we then determine
that the ``HSTORE`` extension is present.
This function is **skipped** if the version of psycopg2 installed is
less than version 2.4.
2. If the ``use_native_hstore`` flag is at its default of ``True``, and
we've detected that ``HSTORE`` oids are available, the
``psycopg2.extensions.register_hstore()`` extension is invoked for all
connections.
The ``register_hstore()`` extension has the effect of **all Python
dictionaries being accepted as parameters regardless of the type of target
column in SQL**. The dictionaries are converted by this extension into a
textual HSTORE expression. If this behavior is not desired, disable the
use of the hstore extension by setting ``use_native_hstore`` to ``False`` as
follows::
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
use_native_hstore=False)
The ``HSTORE`` type is **still supported** when the
``psycopg2.extensions.register_hstore()`` extension is not used. It merely
means that the coercion between Python dictionaries and the HSTORE
string format, on both the parameter side and the result side, will take
place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2``
which may be more performant.
"""
from __future__ import absolute_import
import random
import re
import decimal
import logging
from sqlalchemy import util
from sqlalchemy import processors
from sqlalchemy.engine import base, default
from sqlalchemy.sql import expression
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, \
PGIdentifierPreparer, PGExecutionContext, \
ENUM, ARRAY
from ... import util, exc
import decimal
from ... import processors
from ...engine import result as _result
from ...sql import expression
from ... import types as sqltypes
from .base import PGDialect, PGCompiler, \
PGIdentifierPreparer, PGExecutionContext, \
ENUM, _DECIMAL_TYPES, _FLOAT_TYPES,\
_INT_TYPES, UUID
from .hstore import HSTORE
from .json import JSON, JSONB
try:
from uuid import UUID as _python_UUID
except ImportError:
_python_UUID = None
logger = logging.getLogger('sqlalchemy.dialects.postgresql')
@ -80,82 +341,113 @@ class _PGNumeric(sqltypes.Numeric):
def result_processor(self, dialect, coltype):
if self.asdecimal:
if coltype in (700, 701):
return processors.to_decimal_processor_factory(decimal.Decimal)
elif coltype == 1700:
if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
decimal.Decimal,
self._effective_decimal_return_scale)
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# pg8000 returns Decimal natively for 1700
return None
else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
else:
if coltype in (700, 701):
if coltype in _FLOAT_TYPES:
# pg8000 returns float natively for 701
return None
elif coltype == 1700:
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float
else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
class _PGEnum(ENUM):
def __init__(self, *arg, **kw):
super(_PGEnum, self).__init__(*arg, **kw)
if self.convert_unicode:
self.convert_unicode = "force"
def result_processor(self, dialect, coltype):
if self.native_enum and util.py2k and self.convert_unicode is True:
# we can't easily use PG's extensions here because
# the OID is on the fly, and we need to give it a python
# function anyway - not really worth it.
self.convert_unicode = "force_nocheck"
return super(_PGEnum, self).result_processor(dialect, coltype)
class _PGArray(ARRAY):
def __init__(self, *arg, **kw):
super(_PGArray, self).__init__(*arg, **kw)
# FIXME: this check won't work for setups that
# have convert_unicode only on their create_engine().
if isinstance(self.item_type, sqltypes.String) and \
self.item_type.convert_unicode:
self.item_type.convert_unicode = "force"
# When we're handed literal SQL, ensure it's a SELECT-query. Since
# 8.3, combining cursors and "FOR UPDATE" has been fine.
SERVER_SIDE_CURSOR_RE = re.compile(
r'\s*SELECT',
re.I | re.UNICODE)
class _PGHStore(HSTORE):
def bind_processor(self, dialect):
if dialect._has_native_hstore:
return None
else:
return super(_PGHStore, self).bind_processor(dialect)
def result_processor(self, dialect, coltype):
if dialect._has_native_hstore:
return None
else:
return super(_PGHStore, self).result_processor(dialect, coltype)
class _PGJSON(JSON):
def result_processor(self, dialect, coltype):
if dialect._has_native_json:
return None
else:
return super(_PGJSON, self).result_processor(dialect, coltype)
class _PGJSONB(JSONB):
def result_processor(self, dialect, coltype):
if dialect._has_native_jsonb:
return None
else:
return super(_PGJSONB, self).result_processor(dialect, coltype)
class _PGUUID(UUID):
def bind_processor(self, dialect):
if not self.as_uuid and dialect.use_native_uuid:
nonetype = type(None)
def process(value):
if value is not None:
value = _python_UUID(value)
return value
return process
def result_processor(self, dialect, coltype):
if not self.as_uuid and dialect.use_native_uuid:
def process(value):
if value is not None:
value = str(value)
return value
return process
_server_side_id = util.counter()
class PGExecutionContext_psycopg2(PGExecutionContext):
def create_cursor(self):
# TODO: coverage for server side cursors + select.for_update()
if self.dialect.server_side_cursors:
is_server_side = \
self.execution_options.get('stream_results', True) and (
(self.compiled and isinstance(self.compiled.statement, expression.Selectable) \
or \
(
(not self.compiled or
isinstance(self.compiled.statement, expression._TextClause))
and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
)
)
else:
is_server_side = self.execution_options.get('stream_results', False)
self.__is_server_side = is_server_side
if is_server_side:
# use server-side cursors:
# http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
return self._connection.connection.cursor(ident)
else:
return self._connection.connection.cursor()
def create_server_side_cursor(self):
# use server-side cursors:
# http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
ident = "c_%s_%s" % (hex(id(self))[2:],
hex(_server_side_id())[2:])
return self._dbapi_connection.cursor(ident)
def get_result_proxy(self):
# TODO: ouch
if logger.isEnabledFor(logging.INFO):
self._log_notices(self.cursor)
if self.__is_server_side:
return base.BufferedRowResultProxy(self)
if self._is_server_side:
return _result.BufferedRowResultProxy(self)
else:
return base.ResultProxy(self)
return _result.ResultProxy(self)
def _log_notices(self, cursor):
for notice in cursor.connection.notices:
# NOTICE messages have a
# NOTICE messages have a
# newline character at the end
logger.info(notice.rstrip())
@ -163,9 +455,10 @@ class PGExecutionContext_psycopg2(PGExecutionContext):
class PGCompiler_psycopg2(PGCompiler):
def visit_mod(self, binary, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right)
def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
def post_process_text(self, text):
return text.replace('%', '%%')
@ -175,47 +468,191 @@ class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace('%', '%%')
class PGDialect_psycopg2(PGDialect):
driver = 'psycopg2'
supports_unicode_statements = False
if util.py2k:
supports_unicode_statements = False
supports_server_side_cursors = True
default_paramstyle = 'pyformat'
# set to true based on psycopg2 version
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_psycopg2
statement_compiler = PGCompiler_psycopg2
preparer = PGIdentifierPreparer_psycopg2
psycopg2_version = (0, 0)
FEATURE_VERSION_MAP = dict(
native_json=(2, 5),
native_jsonb=(2, 5, 4),
sane_multi_rowcount=(2, 0, 9),
array_oid=(2, 4, 3),
hstore_adapter=(2, 4)
)
_has_native_hstore = False
_has_native_json = False
_has_native_jsonb = False
engine_config_types = PGDialect.engine_config_types.union([
('use_native_unicode', util.asbool),
])
colspecs = util.update_copy(
PGDialect.colspecs,
{
sqltypes.Numeric : _PGNumeric,
ENUM : _PGEnum, # needs force_unicode
sqltypes.Enum : _PGEnum, # needs force_unicode
ARRAY : _PGArray, # needs force_unicode
sqltypes.Numeric: _PGNumeric,
ENUM: _PGEnum, # needs force_unicode
sqltypes.Enum: _PGEnum, # needs force_unicode
HSTORE: _PGHStore,
JSON: _PGJSON,
sqltypes.JSON: _PGJSON,
JSONB: _PGJSONB,
UUID: _PGUUID
}
)
def __init__(self, server_side_cursors=False, use_native_unicode=True, **kwargs):
def __init__(self, server_side_cursors=False, use_native_unicode=True,
client_encoding=None,
use_native_hstore=True, use_native_uuid=True,
**kwargs):
PGDialect.__init__(self, **kwargs)
self.server_side_cursors = server_side_cursors
self.use_native_unicode = use_native_unicode
self.use_native_hstore = use_native_hstore
self.use_native_uuid = use_native_uuid
self.supports_unicode_binds = use_native_unicode
self.client_encoding = client_encoding
if self.dbapi and hasattr(self.dbapi, '__version__'):
m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?',
self.dbapi.__version__)
if m:
self.psycopg2_version = tuple(
int(x)
for x in m.group(1, 2, 3)
if x is not None)
def initialize(self, connection):
super(PGDialect_psycopg2, self).initialize(connection)
self._has_native_hstore = self.use_native_hstore and \
self._hstore_oids(connection.connection) \
is not None
self._has_native_json = \
self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_json']
self._has_native_jsonb = \
self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_jsonb']
# http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9
self.supports_sane_multi_rowcount = \
self.psycopg2_version >= \
self.FEATURE_VERSION_MAP['sane_multi_rowcount']
@classmethod
def dbapi(cls):
psycopg = __import__('psycopg2')
return psycopg
import psycopg2
return psycopg2
@classmethod
def _psycopg2_extensions(cls):
from psycopg2 import extensions
return extensions
@classmethod
def _psycopg2_extras(cls):
from psycopg2 import extras
return extras
@util.memoized_property
def _isolation_lookup(self):
extensions = self._psycopg2_extensions()
return {
'AUTOCOMMIT': extensions.ISOLATION_LEVEL_AUTOCOMMIT,
'READ COMMITTED': extensions.ISOLATION_LEVEL_READ_COMMITTED,
'READ UNCOMMITTED': extensions.ISOLATION_LEVEL_READ_UNCOMMITTED,
'REPEATABLE READ': extensions.ISOLATION_LEVEL_REPEATABLE_READ,
'SERIALIZABLE': extensions.ISOLATION_LEVEL_SERIALIZABLE
}
def set_isolation_level(self, connection, level):
try:
level = self._isolation_lookup[level.replace('_', ' ')]
except KeyError:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
"Valid isolation levels for %s are %s" %
(level, self.name, ", ".join(self._isolation_lookup))
)
connection.set_isolation_level(level)
def on_connect(self):
base_on_connect = super(PGDialect_psycopg2, self).on_connect()
extras = self._psycopg2_extras()
extensions = self._psycopg2_extensions()
fns = []
if self.client_encoding is not None:
def on_connect(conn):
conn.set_client_encoding(self.client_encoding)
fns.append(on_connect)
if self.isolation_level is not None:
def on_connect(conn):
self.set_isolation_level(conn, self.isolation_level)
fns.append(on_connect)
if self.dbapi and self.use_native_uuid:
def on_connect(conn):
extras.register_uuid(None, conn)
fns.append(on_connect)
if self.dbapi and self.use_native_unicode:
extensions = __import__('psycopg2.extensions').extensions
def connect(conn):
def on_connect(conn):
extensions.register_type(extensions.UNICODE, conn)
if base_on_connect:
base_on_connect(conn)
return connect
extensions.register_type(extensions.UNICODEARRAY, conn)
fns.append(on_connect)
if self.dbapi and self.use_native_hstore:
def on_connect(conn):
hstore_oids = self._hstore_oids(conn)
if hstore_oids is not None:
oid, array_oid = hstore_oids
kw = {'oid': oid}
if util.py2k:
kw['unicode'] = True
if self.psycopg2_version >= \
self.FEATURE_VERSION_MAP['array_oid']:
kw['array_oid'] = array_oid
extras.register_hstore(conn, **kw)
fns.append(on_connect)
if self.dbapi and self._json_deserializer:
def on_connect(conn):
if self._has_native_json:
extras.register_default_json(
conn, loads=self._json_deserializer)
if self._has_native_jsonb:
extras.register_default_jsonb(
conn, loads=self._json_deserializer)
fns.append(on_connect)
if fns:
def on_connect(conn):
for fn in fns:
fn(conn)
return on_connect
else:
return base_on_connect
return None
@util.memoized_instancemethod
def _hstore_oids(self, conn):
if self.psycopg2_version >= self.FEATURE_VERSION_MAP['hstore_adapter']:
extras = self._psycopg2_extras()
oids = extras.HstoreAdapter.get_oids(conn)
if oids is not None and oids[0]:
return oids[0:2]
return None
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
@ -224,16 +661,42 @@ class PGDialect_psycopg2(PGDialect):
opts.update(url.query)
return ([], opts)
def is_disconnect(self, e):
if isinstance(e, self.dbapi.OperationalError):
return 'closed the connection' in str(e) or 'connection not open' in str(e)
elif isinstance(e, self.dbapi.InterfaceError):
return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
elif isinstance(e, self.dbapi.ProgrammingError):
# yes, it really says "losed", not "closed"
return "losed the connection unexpectedly" in str(e)
else:
return False
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.Error):
# check the "closed" flag. this might not be
# present on old psycopg2 versions. Also,
# this flag doesn't actually help in a lot of disconnect
# situations, so don't rely on it.
if getattr(connection, 'closed', False):
return True
# checks based on strings. in the case that .closed
# didn't cut it, fall back onto these.
str_e = str(e).partition("\n")[0]
for msg in [
# these error messages from libpq: interfaces/libpq/fe-misc.c
# and interfaces/libpq/fe-secure.c.
'terminating connection',
'closed the connection',
'connection not open',
'could not receive data from server',
'could not send data to server',
# psycopg2 client errors, psycopg2/conenction.h,
# psycopg2/cursor.h
'connection already closed',
'cursor already closed',
# not sure where this path is originally from, it may
# be obsolete. It really says "losed", not "closed".
'losed the connection unexpectedly',
# these can occur in newer SSL
'connection has been closed unexpectedly',
'SSL SYSCALL error: Bad file descriptor',
'SSL SYSCALL error: EOF detected',
'SSL error: decryption failed or bad record mac',
]:
idx = str_e.find(msg)
if idx >= 0 and '"' not in str_e[:idx]:
return True
return False
dialect = PGDialect_psycopg2

View File

@ -1,18 +1,25 @@
"""Support for the PostgreSQL database via py-postgresql.
# postgresql/pypostgresql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
Connecting
----------
URLs are of the form `postgresql+pypostgresql://user@password@host:port/dbname[?key=value&key=value...]`.
"""
.. dialect:: postgresql+pypostgresql
:name: py-postgresql
:dbapi: pypostgresql
:connectstring: postgresql+pypostgresql://user:password@host:port/dbname\
[?key=value&key=value...]
:url: http://python.projects.pgfoundry.org/
"""
from sqlalchemy.engine import default
import decimal
from sqlalchemy import util
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
from sqlalchemy import processors
from ... import util
from ... import types as sqltypes
from .base import PGDialect, PGExecutionContext
from ... import processors
class PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
@ -24,9 +31,11 @@ class PGNumeric(sqltypes.Numeric):
else:
return processors.to_float
class PGExecutionContext_pypostgresql(PGExecutionContext):
pass
class PGDialect_pypostgresql(PGDialect):
driver = 'pypostgresql'
@ -36,7 +45,7 @@ class PGDialect_pypostgresql(PGDialect):
default_paramstyle = 'pyformat'
# requires trunk version to support sane rowcounts
# TODO: use dbapi version information to set this flag appropariately
# TODO: use dbapi version information to set this flag appropriately
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
@ -44,8 +53,10 @@ class PGDialect_pypostgresql(PGDialect):
colspecs = util.update_copy(
PGDialect.colspecs,
{
sqltypes.Numeric : PGNumeric,
sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used
sqltypes.Numeric: PGNumeric,
# prevents PGNumeric from being used
sqltypes.Float: sqltypes.Float,
}
)
@ -54,6 +65,23 @@ class PGDialect_pypostgresql(PGDialect):
from postgresql.driver import dbapi20
return dbapi20
_DBAPI_ERROR_NAMES = [
"Error",
"InterfaceError", "DatabaseError", "DataError",
"OperationalError", "IntegrityError", "InternalError",
"ProgrammingError", "NotSupportedError"
]
@util.memoized_property
def dbapi_exception_translation_map(self):
if self.dbapi is None:
return {}
return dict(
(getattr(self.dbapi, name).__name__, name)
for name in self._DBAPI_ERROR_NAMES
)
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if 'port' in opts:
@ -63,7 +91,7 @@ class PGDialect_pypostgresql(PGDialect):
opts.update(url.query)
return ([], opts)
def is_disconnect(self, e):
def is_disconnect(self, e, connection, cursor):
return "connection is closed" in str(e)
dialect = PGDialect_pypostgresql

View File

@ -1,19 +1,46 @@
"""Support for the PostgreSQL database via the zxjdbc JDBC connector.
JDBC Driver
-----------
The official Postgresql JDBC driver is at http://jdbc.postgresql.org/.
# postgresql/zxjdbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.postgresql.base import PGDialect
.. dialect:: postgresql+zxjdbc
:name: zxJDBC for Jython
:dbapi: zxjdbc
:connectstring: postgresql+zxjdbc://scott:tiger@localhost/db
:driverurl: http://jdbc.postgresql.org/
"""
from ...connectors.zxJDBC import ZxJDBCConnector
from .base import PGDialect, PGExecutionContext
class PGExecutionContext_zxjdbc(PGExecutionContext):
def create_cursor(self):
cursor = self._dbapi_connection.cursor()
cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
return cursor
class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect):
jdbc_db_name = 'postgresql'
jdbc_driver_name = 'org.postgresql.Driver'
execution_ctx_cls = PGExecutionContext_zxjdbc
supports_native_decimal = True
def __init__(self, *args, **kwargs):
super(PGDialect_zxjdbc, self).__init__(*args, **kwargs)
from com.ziclix.python.sql.handler import PostgresqlDataHandler
self.DataHandler = PostgresqlDataHandler
def _get_server_version_info(self, connection):
return tuple(int(x) for x in connection.connection.dbversion.split('.'))
parts = connection.connection.dbversion.split('.')
return tuple(int(x) for x in parts)
dialect = PGDialect_zxjdbc

View File

@ -5,20 +5,20 @@ Rules for Migrating TypeEngine classes to 0.6
a. Specifying behavior which needs to occur for bind parameters
or result row columns.
b. Specifying types that are entirely specific to the database
in use and have no analogue in the sqlalchemy.types package.
c. Specifying types where there is an analogue in sqlalchemy.types,
but the database in use takes vendor-specific flags for those
types.
d. If a TypeEngine class doesn't provide any of this, it should be
*removed* from the dialect.
2. the TypeEngine classes are *no longer* used for generating DDL. Dialects
now have a TypeCompiler subclass which uses the same visit_XXX model as
other compilers.
other compilers.
3. the "ischema_names" and "colspecs" dictionaries are now required members on
the Dialect class.
@ -29,7 +29,7 @@ the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this
end users would never need to use _PGNumeric directly. However, if a dialect-specific
type is specifying a type *or* arguments that are not present generically, it should
match the real name of the type on that backend, in uppercase. E.g. postgresql.INET,
mysql.ENUM, postgresql.ARRAY.
mysql.ENUM, postgresql.ARRAY.
Or follow this handy flowchart:
@ -61,8 +61,8 @@ Or follow this handy flowchart:
|
v
the type should
subclass the
UPPERCASE
subclass the
UPPERCASE
type in types.py
(i.e. class BLOB(types.BLOB))
@ -85,15 +85,15 @@ Example 4. MySQL has a SET type, there's no analogue for this in types.py. So
MySQL names it SET in the dialect's base.py, and it subclasses types.String, since
it ultimately deals with strings.
Example 5. Postgresql has a DATETIME type. The DBAPIs handle dates correctly,
and no special arguments are used in PG's DDL beyond what types.py provides.
Postgresql dialect therefore imports types.DATETIME into its base.py.
Example 5. PostgreSQL has a DATETIME type. The DBAPIs handle dates correctly,
and no special arguments are used in PG's DDL beyond what types.py provides.
PostgreSQL dialect therefore imports types.DATETIME into its base.py.
Ideally one should be able to specify a schema using names imported completely from a
dialect, all matching the real name on that backend:
from sqlalchemy.dialects.postgresql import base as pg
t = Table('mytable', metadata,
Column('id', pg.INTEGER, primary_key=True),
Column('name', pg.VARCHAR(300)),
@ -110,36 +110,36 @@ indicate a special type only available in this database, it must be *removed* fr
module and from this dictionary.
6. "ischema_names" indicates string descriptions of types as returned from the database
linked to TypeEngine classes.
linked to TypeEngine classes.
a. The string name should be matched to the most specific type possible within
sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which
case it points to a dialect type. *It doesn't matter* if the dialect has it's
case it points to a dialect type. *It doesn't matter* if the dialect has its
own subclass of that type with special bind/result behavior - reflect to the types.py
UPPERCASE type as much as possible. With very few exceptions, all types
should reflect to an UPPERCASE type.
b. If the dialect contains a matching dialect-specific type that takes extra arguments
which the generic one does not, then point to the dialect-specific type. E.g.
mssql.VARCHAR takes a "collation" parameter which should be preserved.
5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by
a subclass of compiler.GenericTypeCompiler.
a. your TypeCompiler class will receive generic and uppercase types from
sqlalchemy.types. Do not assume the presence of dialect-specific attributes on
these types.
b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with
methods that produce a different DDL name. Uppercase types don't do any kind of
"guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in
all cases, regardless of whether or not that type is legal on the backend database.
c. the visit_UPPERCASE methods *should* be overridden with methods that add additional
arguments and flags to those types.
arguments and flags to those types.
d. the visit_lowercase methods are overridden to provide an interpretation of a generic
type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)".
e. visit_lowercase methods should *never* render strings directly - it should always
be via calling a visit_UPPERCASE() method.

View File

@ -1,5 +1,6 @@
# engine/__init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -9,7 +10,7 @@
The engine package defines the basic components used to interface
DB-API modules with higher-level statement construction,
connection-management, execution and result contexts. The primary
"entry point" class into this package is the Engine and it's public
"entry point" class into this package is the Engine and its public
constructor ``create_engine()``.
This package includes:
@ -50,94 +51,125 @@ url.py
within a URL.
"""
# not sure what this was used for
#import sqlalchemy.databases
from .interfaces import (
Connectable,
CreateEnginePlugin,
Dialect,
ExecutionContext,
ExceptionContext,
from sqlalchemy.engine.base import (
# backwards compat
Compiled,
TypeCompiler
)
from .base import (
Connection,
Engine,
NestedTransaction,
RootTransaction,
Transaction,
TwoPhaseTransaction,
)
from .result import (
BaseRowProxy,
BufferedColumnResultProxy,
BufferedColumnRow,
BufferedRowResultProxy,
Compiled,
Connectable,
Connection,
Dialect,
Engine,
ExecutionContext,
NestedTransaction,
FullyBufferedResultProxy,
ResultProxy,
RootTransaction,
RowProxy,
Transaction,
TwoPhaseTransaction,
TypeCompiler
)
from sqlalchemy.engine import strategies
from sqlalchemy import util
)
from .util import (
connection_memoize
)
__all__ = (
'BufferedColumnResultProxy',
'BufferedColumnRow',
'BufferedRowResultProxy',
'Compiled',
'Connectable',
'Connection',
'Dialect',
'Engine',
'ExecutionContext',
'NestedTransaction',
'ResultProxy',
'RootTransaction',
'RowProxy',
'Transaction',
'TwoPhaseTransaction',
'TypeCompiler',
'create_engine',
'engine_from_config',
)
from . import util, strategies
# backwards compat
from ..sql import ddl
default_strategy = 'plain'
def create_engine(*args, **kwargs):
"""Create a new Engine instance.
"""Create a new :class:`.Engine` instance.
The standard method of specifying the engine is via URL as the
first positional argument, to indicate the appropriate database
dialect and connection arguments, with additional keyword
arguments sent as options to the dialect and resulting Engine.
The standard calling form is to send the URL as the
first positional argument, usually a string
that indicates database dialect and connection arguments::
The URL is a string in the form
``dialect+driver://user:password@host/dbname[?key=value..]``, where
``dialect`` is a database name such as ``mysql``, ``oracle``,
``postgresql``, etc., and ``driver`` the name of a DBAPI, such as
``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively,
engine = create_engine("postgresql://scott:tiger@localhost/test")
Additional keyword arguments may then follow it which
establish various options on the resulting :class:`.Engine`
and its underlying :class:`.Dialect` and :class:`.Pool`
constructs::
engine = create_engine("mysql://scott:tiger@hostname/dbname",
encoding='latin1', echo=True)
The string form of the URL is
``dialect[+driver]://user:password@host/dbname[?key=value..]``, where
``dialect`` is a database name such as ``mysql``, ``oracle``,
``postgresql``, etc., and ``driver`` the name of a DBAPI, such as
``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively,
the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`.
`**kwargs` takes a wide variety of options which are routed
towards their appropriate components. Arguments may be
specific to the Engine, the underlying Dialect, as well as the
Pool. Specific dialects also accept keyword arguments that
``**kwargs`` takes a wide variety of options which are routed
towards their appropriate components. Arguments may be specific to
the :class:`.Engine`, the underlying :class:`.Dialect`, as well as the
:class:`.Pool`. Specific dialects also accept keyword arguments that
are unique to that dialect. Here, we describe the parameters
that are common to most ``create_engine()`` usage.
that are common to most :func:`.create_engine()` usage.
:param assert_unicode: Deprecated. A warning is raised in all cases when a non-Unicode
object is passed when SQLAlchemy would coerce into an encoding
(note: but **not** when the DBAPI handles unicode objects natively).
To suppress or raise this warning to an
error, use the Python warnings filter documented at:
http://docs.python.org/library/warnings.html
Once established, the newly resulting :class:`.Engine` will
request a connection from the underlying :class:`.Pool` once
:meth:`.Engine.connect` is called, or a method which depends on it
such as :meth:`.Engine.execute` is invoked. The :class:`.Pool` in turn
will establish the first actual DBAPI connection when this request
is received. The :func:`.create_engine` call itself does **not**
establish any actual DBAPI connections directly.
.. seealso::
:doc:`/core/engines`
:doc:`/dialects/index`
:ref:`connections_toplevel`
:param case_sensitive=True: if False, result column names
will match in a case-insensitive fashion, that is,
``row['SomeColumn']``.
.. versionchanged:: 0.8
By default, result row names match case-sensitively.
In version 0.7 and prior, all matches were case-insensitive.
:param connect_args: a dictionary of options which will be
passed directly to the DBAPI's ``connect()`` method as
additional keyword arguments.
additional keyword arguments. See the example
at :ref:`custom_dbapi_args`.
:param convert_unicode=False: if set to True, all
String/character based types will convert Unicode values to raw
byte values going into the database, and all raw byte values to
Python Unicode coming out in result sets. This is an
engine-wide method to provide unicode conversion across the
board. For unicode conversion on a column-by-column level, use
the ``Unicode`` column type instead, described in `types`.
:param convert_unicode=False: if set to True, sets
the default behavior of ``convert_unicode`` on the
:class:`.String` type to ``True``, regardless
of a setting of ``False`` on an individual
:class:`.String` type, thus causing all :class:`.String`
-based columns
to accommodate Python ``unicode`` objects. This flag
is useful as an engine-wide setting when using a
DBAPI that does not natively support Python
``unicode`` objects and raises an error when
one is received (such as pyodbc with FreeTDS).
See :class:`.String` for further details on
what this flag indicates.
:param creator: a callable which returns a DBAPI connection.
This creation function will be passed to the underlying
@ -160,23 +192,105 @@ def create_engine(*args, **kwargs):
:ref:`dbengine_logging` for information on how to configure logging
directly.
:param encoding='utf-8': the encoding to use for all Unicode
translations, both by engine-wide unicode conversion as well as
the ``Unicode`` type object.
:param encoding: Defaults to ``utf-8``. This is the string
encoding used by SQLAlchemy for string encode/decode
operations which occur within SQLAlchemy, **outside of
the DBAPI.** Most modern DBAPIs feature some degree of
direct support for Python ``unicode`` objects,
what you see in Python 2 as a string of the form
``u'some string'``. For those scenarios where the
DBAPI is detected as not supporting a Python ``unicode``
object, this encoding is used to determine the
source/destination encoding. It is **not used**
for those cases where the DBAPI handles unicode
directly.
To properly configure a system to accommodate Python
``unicode`` objects, the DBAPI should be
configured to handle unicode to the greatest
degree as is appropriate - see
the notes on unicode pertaining to the specific
target database in use at :ref:`dialect_toplevel`.
Areas where string encoding may need to be accommodated
outside of the DBAPI include zero or more of:
* the values passed to bound parameters, corresponding to
the :class:`.Unicode` type or the :class:`.String` type
when ``convert_unicode`` is ``True``;
* the values returned in result set columns corresponding
to the :class:`.Unicode` type or the :class:`.String`
type when ``convert_unicode`` is ``True``;
* the string SQL statement passed to the DBAPI's
``cursor.execute()`` method;
* the string names of the keys in the bound parameter
dictionary passed to the DBAPI's ``cursor.execute()``
as well as ``cursor.setinputsizes()`` methods;
* the string column names retrieved from the DBAPI's
``cursor.description`` attribute.
When using Python 3, the DBAPI is required to support
*all* of the above values as Python ``unicode`` objects,
which in Python 3 are just known as ``str``. In Python 2,
the DBAPI does not specify unicode behavior at all,
so SQLAlchemy must make decisions for each of the above
values on a per-DBAPI basis - implementations are
completely inconsistent in their behavior.
:param execution_options: Dictionary execution options which will
be applied to all connections. See
:meth:`~sqlalchemy.engine.Connection.execution_options`
:param implicit_returning=True: When ``True``, a RETURNING-
compatible construct, if available, will be used to
fetch newly generated primary key values when a single row
INSERT statement is emitted with no existing returning()
clause. This applies to those backends which support RETURNING
or a compatible construct, including PostgreSQL, Firebird, Oracle,
Microsoft SQL Server. Set this to ``False`` to disable
the automatic usage of RETURNING.
:param isolation_level: this string parameter is interpreted by various
dialects in order to affect the transaction isolation level of the
database connection. The parameter essentially accepts some subset of
these string arguments: ``"SERIALIZABLE"``, ``"REPEATABLE_READ"``,
``"READ_COMMITTED"``, ``"READ_UNCOMMITTED"`` and ``"AUTOCOMMIT"``.
Behavior here varies per backend, and
individual dialects should be consulted directly.
Note that the isolation level can also be set on a per-:class:`.Connection`
basis as well, using the
:paramref:`.Connection.execution_options.isolation_level`
feature.
.. seealso::
:attr:`.Connection.default_isolation_level` - view default level
:paramref:`.Connection.execution_options.isolation_level`
- set per :class:`.Connection` isolation level
:ref:`SQLite Transaction Isolation <sqlite_isolation_level>`
:ref:`PostgreSQL Transaction Isolation <postgresql_isolation_level>`
:ref:`MySQL Transaction Isolation <mysql_isolation_level>`
:ref:`session_transaction_isolation` - for the ORM
:param label_length=None: optional integer value which limits
the size of dynamically generated column labels to that many
characters. If less than 6, labels are generated as
"_(counter)". If ``None``, the value of
``dialect.max_identifier_length`` is used instead.
:param listeners: A list of one or more
:class:`~sqlalchemy.interfaces.PoolListener` objects which will
:param listeners: A list of one or more
:class:`~sqlalchemy.interfaces.PoolListener` objects which will
receive connection pool events.
:param logging_name: String identifier which will be used within
the "name" field of logging records generated within the
"sqlalchemy.engine" logger. Defaults to a hexstring of the
"sqlalchemy.engine" logger. Defaults to a hexstring of the
object's id.
:param max_overflow=10: the number of connections to allow in
@ -184,10 +298,24 @@ def create_engine(*args, **kwargs):
opened above and beyond the pool_size setting, which defaults
to five. this is only used with :class:`~sqlalchemy.pool.QueuePool`.
:param module=None: used by database implementations which
support multiple DBAPI modules, this is a reference to a DBAPI2
module to be used instead of the engine's default module. For
PostgreSQL, the default is psycopg2. For Oracle, it's cx_Oracle.
:param module=None: reference to a Python module object (the module
itself, not its string name). Specifies an alternate DBAPI module to
be used by the engine's dialect. Each sub-dialect references a
specific DBAPI which will be imported before first connect. This
parameter causes the import to be bypassed, and the given module to
be used instead. Can be used for testing of DBAPIs as well as to
inject "mock" DBAPI implementations into the :class:`.Engine`.
:param paramstyle=None: The `paramstyle <http://legacy.python.org/dev/peps/pep-0249/#paramstyle>`_
to use when rendering bound parameters. This style defaults to the
one recommended by the DBAPI itself, which is retrieved from the
``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept
more than one paramstyle, and in particular it may be desirable
to change a "named" paramstyle into a "positional" one, or vice versa.
When this attribute is passed, it should be one of the values
``"qmark"``, ``"numeric"``, ``"named"``, ``"format"`` or
``"pyformat"``, and should correspond to a parameter style known
to be supported by the DBAPI in use.
:param pool=None: an already-constructed instance of
:class:`~sqlalchemy.pool.Pool`, such as a
@ -195,7 +323,7 @@ def create_engine(*args, **kwargs):
pool will be used directly as the underlying connection pool
for the engine, bypassing whatever connection parameters are
present in the URL argument. For information on constructing
connection pools manually, see `pooling`.
connection pools manually, see :ref:`pooling_toplevel`.
:param poolclass=None: a :class:`~sqlalchemy.pool.Pool`
subclass, which will be used to create a connection pool
@ -205,70 +333,102 @@ def create_engine(*args, **kwargs):
of pool to be used.
:param pool_logging_name: String identifier which will be used within
the "name" field of logging records generated within the
"sqlalchemy.pool" logger. Defaults to a hexstring of the object's
the "name" field of logging records generated within the
"sqlalchemy.pool" logger. Defaults to a hexstring of the object's
id.
:param pool_size=5: the number of connections to keep open
inside the connection pool. This used with :class:`~sqlalchemy.pool.QueuePool` as
well as :class:`~sqlalchemy.pool.SingletonThreadPool`.
inside the connection pool. This used with
:class:`~sqlalchemy.pool.QueuePool` as
well as :class:`~sqlalchemy.pool.SingletonThreadPool`. With
:class:`~sqlalchemy.pool.QueuePool`, a ``pool_size`` setting
of 0 indicates no limit; to disable pooling, set ``poolclass`` to
:class:`~sqlalchemy.pool.NullPool` instead.
:param pool_recycle=-1: this setting causes the pool to recycle
connections after the given number of seconds has passed. It
defaults to -1, or no timeout. For example, setting to 3600
means connections will be recycled after one hour. Note that
MySQL in particular will ``disconnect automatically`` if no
MySQL in particular will disconnect automatically if no
activity is detected on a connection for eight hours (although
this is configurable with the MySQLDB connection itself and the
server configuration as well).
:param pool_reset_on_return='rollback': set the "reset on return"
behavior of the pool, which is whether ``rollback()``,
``commit()``, or nothing is called upon connections
being returned to the pool. See the docstring for
``reset_on_return`` at :class:`.Pool`.
.. versionadded:: 0.7.6
:param pool_timeout=30: number of seconds to wait before giving
up on getting a connection from the pool. This is only used
with :class:`~sqlalchemy.pool.QueuePool`.
:param strategy='plain': used to invoke alternate :class:`~sqlalchemy.engine.base.Engine.`
implementations. Currently available is the ``threadlocal``
strategy, which is described in :ref:`threadlocal_strategy`.
:param strategy='plain': selects alternate engine implementations.
Currently available are:
* the ``threadlocal`` strategy, which is described in
:ref:`threadlocal_strategy`;
* the ``mock`` strategy, which dispatches all statement
execution to a function passed as the argument ``executor``.
See `example in the FAQ
<http://docs.sqlalchemy.org/en/latest/faq/metadata_schema.html#how-can-i-get-the-create-table-drop-table-output-as-a-string>`_.
:param executor=None: a function taking arguments
``(sql, *multiparams, **params)``, to which the ``mock`` strategy will
dispatch all statement execution. Used only by ``strategy='mock'``.
"""
strategy = kwargs.pop('strategy', default_strategy)
strategy = strategies.strategies[strategy]
return strategy.create(*args, **kwargs)
def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
"""Create a new Engine instance using a configuration dictionary.
The dictionary is typically produced from a config file where keys
are prefixed, such as sqlalchemy.url, sqlalchemy.echo, etc. The
'prefix' argument indicates the prefix to be searched for.
The dictionary is typically produced from a config file.
The keys of interest to ``engine_from_config()`` should be prefixed, e.g.
``sqlalchemy.url``, ``sqlalchemy.echo``, etc. The 'prefix' argument
indicates the prefix to be searched for. Each matching key (after the
prefix is stripped) is treated as though it were the corresponding keyword
argument to a :func:`.create_engine` call.
The only required key is (assuming the default prefix) ``sqlalchemy.url``,
which provides the :ref:`database URL <database_urls>`.
A select set of keyword arguments will be "coerced" to their
expected type based on string values. In a future release, this
functionality will be expanded and include dialect-specific
arguments.
expected type based on string values. The set of arguments
is extensible per-dialect using the ``engine_config_types`` accessor.
:param configuration: A dictionary (typically produced from a config file,
but this is not a requirement). Items whose keys start with the value
of 'prefix' will have that prefix stripped, and will then be passed to
:ref:`create_engine`.
:param prefix: Prefix to match and then strip from keys
in 'configuration'.
:param kwargs: Each keyword argument to ``engine_from_config()`` itself
overrides the corresponding item taken from the 'configuration'
dictionary. Keyword arguments should *not* be prefixed.
"""
opts = _coerce_config(configuration, prefix)
opts.update(kwargs)
url = opts.pop('url')
return create_engine(url, **opts)
def _coerce_config(configuration, prefix):
"""Convert configuration values to expected types."""
options = dict((key[len(prefix):], configuration[key])
for key in configuration
if key.startswith(prefix))
for option, type_ in (
('convert_unicode', bool),
('pool_timeout', int),
('echo', bool),
('echo_pool', bool),
('pool_recycle', int),
('pool_size', int),
('max_overflow', int),
('pool_threadlocal', bool),
):
util.coerce_kw_type(options, option, type_)
return options
options['_coerce_config'] = True
options.update(kwargs)
url = options.pop('url')
return create_engine(url, **options)
__all__ = (
'create_engine',
'engine_from_config',
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,10 @@
# engine/reflection.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Provides an abstraction for obtaining database schema information.
Usage Notes:
@ -18,11 +25,14 @@ methods such as get_table_names, get_columns, etc.
'name' attribute..
"""
import sqlalchemy
from sqlalchemy import exc, sql
from sqlalchemy import util
from sqlalchemy.types import TypeEngine
from sqlalchemy import schema as sa_schema
from .. import exc, sql
from ..sql import schema as sa_schema
from .. import util
from ..sql.type_api import TypeEngine
from ..util import deprecated
from ..util import topological
from .. import inspection
from .base import Connectable
@util.decorator
@ -31,10 +41,14 @@ def cache(fn, self, con, *args, **kw):
if info_cache is None:
return fn(self, con, *args, **kw)
key = (
fn.__name__,
tuple(a for a in args if isinstance(a, basestring)),
tuple((k, v) for k, v in kw.iteritems() if isinstance(v, (basestring, int, float)))
)
fn.__name__,
tuple(a for a in args if isinstance(a, util.string_types)),
tuple((k, v) for k, v in kw.items() if
isinstance(v,
util.string_types + util.int_types + (float, )
)
)
)
ret = info_cache.get(key)
if ret is None:
ret = fn(self, con, *args, **kw)
@ -45,33 +59,94 @@ def cache(fn, self, con, *args, **kw):
class Inspector(object):
"""Performs database schema inspection.
The Inspector acts as a proxy to the dialects' reflection methods and
provides higher level functions for accessing database schema information.
The Inspector acts as a proxy to the reflection methods of the
:class:`~sqlalchemy.engine.interfaces.Dialect`, providing a
consistent interface as well as caching support for previously
fetched metadata.
A :class:`.Inspector` object is usually created via the
:func:`.inspect` function::
from sqlalchemy import inspect, create_engine
engine = create_engine('...')
insp = inspect(engine)
The inspection method above is equivalent to using the
:meth:`.Inspector.from_engine` method, i.e.::
engine = create_engine('...')
insp = Inspector.from_engine(engine)
Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` may opt
to return an :class:`.Inspector` subclass that provides additional
methods specific to the dialect's target database.
"""
def __init__(self, conn):
"""Initialize the instance.
def __init__(self, bind):
"""Initialize a new :class:`.Inspector`.
:param bind: a :class:`~sqlalchemy.engine.Connectable`,
which is typically an instance of
:class:`~sqlalchemy.engine.Engine` or
:class:`~sqlalchemy.engine.Connection`.
For a dialect-specific instance of :class:`.Inspector`, see
:meth:`.Inspector.from_engine`
:param conn: a :class:`~sqlalchemy.engine.base.Connectable`
"""
# this might not be a connection, it could be an engine.
self.bind = bind
self.conn = conn
# set the engine
if hasattr(conn, 'engine'):
self.engine = conn.engine
if hasattr(bind, 'engine'):
self.engine = bind.engine
else:
self.engine = conn
self.engine = bind
if self.engine is bind:
# if engine, ensure initialized
bind.connect().close()
self.dialect = self.engine.dialect
self.info_cache = {}
@classmethod
def from_engine(cls, engine):
if hasattr(engine.dialect, 'inspector'):
return engine.dialect.inspector(engine)
return Inspector(engine)
def from_engine(cls, bind):
"""Construct a new dialect-specific Inspector object from the given
engine or connection.
:param bind: a :class:`~sqlalchemy.engine.Connectable`,
which is typically an instance of
:class:`~sqlalchemy.engine.Engine` or
:class:`~sqlalchemy.engine.Connection`.
This method differs from direct a direct constructor call of
:class:`.Inspector` in that the
:class:`~sqlalchemy.engine.interfaces.Dialect` is given a chance to
provide a dialect-specific :class:`.Inspector` instance, which may
provide additional methods.
See the example at :class:`.Inspector`.
"""
if hasattr(bind.dialect, 'inspector'):
return bind.dialect.inspector(bind)
return Inspector(bind)
@inspection._inspects(Connectable)
def _insp(bind):
return Inspector.from_engine(bind)
@property
def default_schema_name(self):
"""Return the default schema name presented by the dialect
for the current engine's database user.
E.g. this is typically ``public`` for PostgreSQL and ``dbo``
for SQL Server.
"""
return self.dialect.default_schema_name
def get_schema_names(self):
@ -79,70 +154,185 @@ class Inspector(object):
"""
if hasattr(self.dialect, 'get_schema_names'):
return self.dialect.get_schema_names(self.conn,
info_cache=self.info_cache)
return self.dialect.get_schema_names(self.bind,
info_cache=self.info_cache)
return []
def get_table_names(self, schema=None, order_by=None):
"""Return all table names in `schema`.
"""Return all table names in referred to within a particular schema.
The names are expected to be real tables only, not views.
Views are instead returned using the :meth:`.Inspector.get_view_names`
method.
:param schema: Schema name. If ``schema`` is left at ``None``, the
database's default schema is
used, else the named schema is searched. If the database does not
support named schemas, behavior is undefined if ``schema`` is not
passed as ``None``. For special quoting, use :class:`.quoted_name`.
:param schema: Optional, retrieve names from a non-default schema.
:param order_by: Optional, may be the string "foreign_key" to sort
the result on foreign key dependencies.
the result on foreign key dependencies. Does not automatically
resolve cycles, and will raise :class:`.CircularDependencyError`
if cycles exist.
.. deprecated:: 1.0.0 - see
:meth:`.Inspector.get_sorted_table_and_fkc_names` for a version
of this which resolves foreign key cycles between tables
automatically.
.. versionchanged:: 0.8 the "foreign_key" sorting sorts tables
in order of dependee to dependent; that is, in creation
order, rather than in drop order. This is to maintain
consistency with similar features such as
:attr:`.MetaData.sorted_tables` and :func:`.util.sort_tables`.
.. seealso::
:meth:`.Inspector.get_sorted_table_and_fkc_names`
:attr:`.MetaData.sorted_tables`
This should probably not return view names or maybe it should return
them with an indicator t or v.
"""
if hasattr(self.dialect, 'get_table_names'):
tnames = self.dialect.get_table_names(self.conn,
schema,
info_cache=self.info_cache)
tnames = self.dialect.get_table_names(
self.bind, schema, info_cache=self.info_cache)
else:
tnames = self.engine.table_names(schema)
if order_by == 'foreign_key':
ordered_tnames = tnames[:]
# Order based on foreign key dependencies.
tuples = []
for tname in tnames:
table_pos = tnames.index(tname)
fkeys = self.get_foreign_keys(tname, schema)
for fkey in fkeys:
rtable = fkey['referred_table']
if rtable in ordered_tnames:
ref_pos = ordered_tnames.index(rtable)
# Make sure it's lower in the list than anything it
# references.
if table_pos > ref_pos:
ordered_tnames.pop(table_pos) # rtable moves up 1
# insert just below rtable
ordered_tnames.index(ref_pos, tname)
tnames = ordered_tnames
for fkey in self.get_foreign_keys(tname, schema):
if tname != fkey['referred_table']:
tuples.append((fkey['referred_table'], tname))
tnames = list(topological.sort(tuples, tnames))
return tnames
def get_sorted_table_and_fkc_names(self, schema=None):
"""Return dependency-sorted table and foreign key constraint names in
referred to within a particular schema.
This will yield 2-tuples of
``(tablename, [(tname, fkname), (tname, fkname), ...])``
consisting of table names in CREATE order grouped with the foreign key
constraint names that are not detected as belonging to a cycle.
The final element
will be ``(None, [(tname, fkname), (tname, fkname), ..])``
which will consist of remaining
foreign key constraint names that would require a separate CREATE
step after-the-fact, based on dependencies between tables.
.. versionadded:: 1.0.-
.. seealso::
:meth:`.Inspector.get_table_names`
:func:`.sort_tables_and_constraints` - similar method which works
with an already-given :class:`.MetaData`.
"""
if hasattr(self.dialect, 'get_table_names'):
tnames = self.dialect.get_table_names(
self.bind, schema, info_cache=self.info_cache)
else:
tnames = self.engine.table_names(schema)
tuples = set()
remaining_fkcs = set()
fknames_for_table = {}
for tname in tnames:
fkeys = self.get_foreign_keys(tname, schema)
fknames_for_table[tname] = set(
[fk['name'] for fk in fkeys]
)
for fkey in fkeys:
if tname != fkey['referred_table']:
tuples.add((fkey['referred_table'], tname))
try:
candidate_sort = list(topological.sort(tuples, tnames))
except exc.CircularDependencyError as err:
for edge in err.edges:
tuples.remove(edge)
remaining_fkcs.update(
(edge[1], fkc)
for fkc in fknames_for_table[edge[1]]
)
candidate_sort = list(topological.sort(tuples, tnames))
return [
(tname, fknames_for_table[tname].difference(remaining_fkcs))
for tname in candidate_sort
] + [(None, list(remaining_fkcs))]
def get_temp_table_names(self):
"""return a list of temporary table names for the current bind.
This method is unsupported by most dialects; currently
only SQLite implements it.
.. versionadded:: 1.0.0
"""
return self.dialect.get_temp_table_names(
self.bind, info_cache=self.info_cache)
def get_temp_view_names(self):
"""return a list of temporary view names for the current bind.
This method is unsupported by most dialects; currently
only SQLite implements it.
.. versionadded:: 1.0.0
"""
return self.dialect.get_temp_view_names(
self.bind, info_cache=self.info_cache)
def get_table_options(self, table_name, schema=None, **kw):
"""Return a dictionary of options specified when the table of the
given name was created.
This currently includes some options that apply to MySQL tables.
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
"""
if hasattr(self.dialect, 'get_table_options'):
return self.dialect.get_table_options(self.conn, table_name, schema,
info_cache=self.info_cache,
**kw)
return self.dialect.get_table_options(
self.bind, table_name, schema,
info_cache=self.info_cache, **kw)
return {}
def get_view_names(self, schema=None):
"""Return all view names in `schema`.
:param schema: Optional, retrieve names from a non-default schema.
For special quoting, use :class:`.quoted_name`.
"""
return self.dialect.get_view_names(self.conn, schema,
info_cache=self.info_cache)
return self.dialect.get_view_names(self.bind, schema,
info_cache=self.info_cache)
def get_view_definition(self, view_name, schema=None):
"""Return definition for `view_name`.
:param schema: Optional, retrieve names from a non-default schema.
For special quoting, use :class:`.quoted_name`.
"""
return self.dialect.get_view_definition(
self.conn, view_name, schema, info_cache=self.info_cache)
self.bind, view_name, schema, info_cache=self.info_cache)
def get_columns(self, table_name, schema=None, **kw):
"""Return information about columns in `table_name`.
@ -150,23 +340,31 @@ class Inspector(object):
Given a string `table_name` and an optional string `schema`, return
column information as a list of dicts with these keys:
name
the column's name
* ``name`` - the column's name
type
* ``type`` - the type of this column; an instance of
:class:`~sqlalchemy.types.TypeEngine`
nullable
boolean
* ``nullable`` - boolean flag if the column is NULL or NOT NULL
default
the column's default value
* ``default`` - the column's server default value - this is returned
as a string SQL expression.
* ``attrs`` - dict containing optional column attributes
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
:return: list of dictionaries, each representing the definition of
a database column.
attrs
dict containing optional column attributes
"""
col_defs = self.dialect.get_columns(self.conn, table_name, schema,
col_defs = self.dialect.get_columns(self.bind, table_name, schema,
info_cache=self.info_cache,
**kw)
for col_def in col_defs:
@ -176,6 +374,8 @@ class Inspector(object):
col_def['type'] = coltype()
return col_defs
@deprecated('0.7', 'Call to deprecated method get_primary_keys.'
' Use get_pk_constraint instead.')
def get_primary_keys(self, table_name, schema=None, **kw):
"""Return information about primary keys in `table_name`.
@ -183,12 +383,34 @@ class Inspector(object):
primary key information as a list of column names.
"""
pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema,
return self.dialect.get_pk_constraint(self.bind, table_name, schema,
info_cache=self.info_cache,
**kw)['constrained_columns']
def get_pk_constraint(self, table_name, schema=None, **kw):
"""Return information about primary key constraint on `table_name`.
Given a string `table_name`, and an optional string `schema`, return
primary key information as a dictionary with these keys:
constrained_columns
a list of column names that make up the primary key
name
optional name of the primary key constraint.
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
"""
return self.dialect.get_pk_constraint(self.bind, table_name, schema,
info_cache=self.info_cache,
**kw)
return pkeys
def get_foreign_keys(self, table_name, schema=None, **kw):
"""Return information about foreign_keys in `table_name`.
@ -208,15 +430,21 @@ class Inspector(object):
a list of column names in the referred table that correspond to
constrained_columns
\**kw
other options passed to the dialect's get_foreign_keys() method.
name
optional name of the foreign key constraint.
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
"""
fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema,
info_cache=self.info_cache,
**kw)
return fk_defs
return self.dialect.get_foreign_keys(self.bind, table_name, schema,
info_cache=self.info_cache,
**kw)
def get_indexes(self, table_name, schema=None, **kw):
"""Return information about indexes in `table_name`.
@ -232,104 +460,261 @@ class Inspector(object):
unique
boolean
\**kw
other options passed to the dialect's get_indexes() method.
dialect_options
dict of dialect-specific index options. May not be present
for all dialects.
.. versionadded:: 1.0.0
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
"""
indexes = self.dialect.get_indexes(self.conn, table_name,
schema,
info_cache=self.info_cache, **kw)
return indexes
return self.dialect.get_indexes(self.bind, table_name,
schema,
info_cache=self.info_cache, **kw)
def reflecttable(self, table, include_columns):
def get_unique_constraints(self, table_name, schema=None, **kw):
"""Return information about unique constraints in `table_name`.
dialect = self.conn.dialect
Given a string `table_name` and an optional string `schema`, return
unique constraint information as a list of dicts with these keys:
# MySQL dialect does this. Applicable with other dialects?
if hasattr(dialect, '_connection_charset') \
and hasattr(dialect, '_adjust_casing'):
charset = dialect._connection_charset
dialect._adjust_casing(table)
name
the unique constraint's name
# table attributes we might need.
reflection_options = dict(
(k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs)
column_names
list of column names in order
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
.. versionadded:: 0.8.4
"""
return self.dialect.get_unique_constraints(
self.bind, table_name, schema, info_cache=self.info_cache, **kw)
def get_check_constraints(self, table_name, schema=None, **kw):
"""Return information about check constraints in `table_name`.
Given a string `table_name` and an optional string `schema`, return
check constraint information as a list of dicts with these keys:
name
the check constraint's name
sqltext
the check constraint's SQL expression
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
.. versionadded:: 1.1.0
"""
return self.dialect.get_check_constraints(
self.bind, table_name, schema, info_cache=self.info_cache, **kw)
def reflecttable(self, table, include_columns, exclude_columns=(),
_extend_on=None):
"""Given a Table object, load its internal constructs based on
introspection.
This is the underlying method used by most dialects to produce
table reflection. Direct usage is like::
from sqlalchemy import create_engine, MetaData, Table
from sqlalchemy.engine import reflection
engine = create_engine('...')
meta = MetaData()
user_table = Table('user', meta)
insp = Inspector.from_engine(engine)
insp.reflecttable(user_table, None)
:param table: a :class:`~sqlalchemy.schema.Table` instance.
:param include_columns: a list of string column names to include
in the reflection process. If ``None``, all columns are reflected.
"""
if _extend_on is not None:
if table in _extend_on:
return
else:
_extend_on.add(table)
dialect = self.bind.dialect
schema = self.bind.schema_for_object(table)
schema = table.schema
table_name = table.name
# apply table options
tbl_opts = self.get_table_options(table_name, schema, **table.kwargs)
# get table-level arguments that are specifically
# intended for reflection, e.g. oracle_resolve_synonyms.
# these are unconditionally passed to related Table
# objects
reflection_options = dict(
(k, table.dialect_kwargs.get(k))
for k in dialect.reflection_options
if k in table.dialect_kwargs
)
# reflect table options, like mysql_engine
tbl_opts = self.get_table_options(
table_name, schema, **table.dialect_kwargs)
if tbl_opts:
table.kwargs.update(tbl_opts)
# add additional kwargs to the Table if the dialect
# returned them
table._validate_dialect_kwargs(tbl_opts)
# table.kwargs will need to be passed to each reflection method. Make
# sure keywords are strings.
tblkw = table.kwargs.copy()
for (k, v) in tblkw.items():
del tblkw[k]
tblkw[str(k)] = v
if util.py2k:
if isinstance(schema, str):
schema = schema.decode(dialect.encoding)
if isinstance(table_name, str):
table_name = table_name.decode(dialect.encoding)
# Py2K
if isinstance(schema, str):
schema = schema.decode(dialect.encoding)
if isinstance(table_name, str):
table_name = table_name.decode(dialect.encoding)
# end Py2K
# columns
found_table = False
for col_d in self.get_columns(table_name, schema, **tblkw):
found_table = True
name = col_d['name']
if include_columns and name not in include_columns:
continue
cols_by_orig_name = {}
coltype = col_d['type']
col_kw = {
'nullable':col_d['nullable'],
}
if 'autoincrement' in col_d:
col_kw['autoincrement'] = col_d['autoincrement']
if 'quote' in col_d:
col_kw['quote'] = col_d['quote']
colargs = []
if col_d.get('default') is not None:
# the "default" value is assumed to be a literal SQL expression,
# so is wrapped in text() so that no quoting occurs on re-issuance.
colargs.append(sa_schema.DefaultClause(sql.text(col_d['default'])))
if 'sequence' in col_d:
# TODO: mssql, maxdb and sybase are using this.
seq = col_d['sequence']
sequence = sa_schema.Sequence(seq['name'], 1, 1)
if 'start' in seq:
sequence.start = seq['start']
if 'increment' in seq:
sequence.increment = seq['increment']
colargs.append(sequence)
col = sa_schema.Column(name, coltype, *colargs, **col_kw)
table.append_column(col)
for col_d in self.get_columns(
table_name, schema, **table.dialect_kwargs):
found_table = True
self._reflect_column(
table, col_d, include_columns,
exclude_columns, cols_by_orig_name)
if not found_table:
raise exc.NoSuchTableError(table.name)
# Primary keys
primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[
table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw)
if pk in table.c
])
self._reflect_pk(
table_name, schema, table, cols_by_orig_name, exclude_columns)
table.append_constraint(primary_key_constraint)
self._reflect_fk(
table_name, schema, table, cols_by_orig_name,
exclude_columns, _extend_on, reflection_options)
# Foreign keys
fkeys = self.get_foreign_keys(table_name, schema, **tblkw)
self._reflect_indexes(
table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options)
self._reflect_unique_constraints(
table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options)
self._reflect_check_constraints(
table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options)
def _reflect_column(
self, table, col_d, include_columns,
exclude_columns, cols_by_orig_name):
orig_name = col_d['name']
table.dispatch.column_reflect(self, table, col_d)
# fetch name again as column_reflect is allowed to
# change it
name = col_d['name']
if (include_columns and name not in include_columns) \
or (exclude_columns and name in exclude_columns):
return
coltype = col_d['type']
col_kw = dict(
(k, col_d[k])
for k in ['nullable', 'autoincrement', 'quote', 'info', 'key']
if k in col_d
)
colargs = []
if col_d.get('default') is not None:
default = col_d['default']
if isinstance(default, sql.elements.TextClause):
default = sa_schema.DefaultClause(default, _reflected=True)
elif not isinstance(default, sa_schema.FetchedValue):
default = sa_schema.DefaultClause(
sql.text(col_d['default']), _reflected=True)
colargs.append(default)
if 'sequence' in col_d:
self._reflect_col_sequence(col_d, colargs)
cols_by_orig_name[orig_name] = col = \
sa_schema.Column(name, coltype, *colargs, **col_kw)
if col.key in table.primary_key:
col.primary_key = True
table.append_column(col)
def _reflect_col_sequence(self, col_d, colargs):
if 'sequence' in col_d:
# TODO: mssql and sybase are using this.
seq = col_d['sequence']
sequence = sa_schema.Sequence(seq['name'], 1, 1)
if 'start' in seq:
sequence.start = seq['start']
if 'increment' in seq:
sequence.increment = seq['increment']
colargs.append(sequence)
def _reflect_pk(
self, table_name, schema, table,
cols_by_orig_name, exclude_columns):
pk_cons = self.get_pk_constraint(
table_name, schema, **table.dialect_kwargs)
if pk_cons:
pk_cols = [
cols_by_orig_name[pk]
for pk in pk_cons['constrained_columns']
if pk in cols_by_orig_name and pk not in exclude_columns
]
# update pk constraint name
table.primary_key.name = pk_cons.get('name')
# tell the PKConstraint to re-initialize
# its column collection
table.primary_key._reload(pk_cols)
def _reflect_fk(
self, table_name, schema, table, cols_by_orig_name,
exclude_columns, _extend_on, reflection_options):
fkeys = self.get_foreign_keys(
table_name, schema, **table.dialect_kwargs)
for fkey_d in fkeys:
conname = fkey_d['name']
constrained_columns = fkey_d['constrained_columns']
# look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
constrained_columns = [
cols_by_orig_name[c].key
if c in cols_by_orig_name else c
for c in fkey_d['constrained_columns']
]
if exclude_columns and set(constrained_columns).intersection(
exclude_columns):
continue
referred_schema = fkey_d['referred_schema']
referred_table = fkey_d['referred_table']
referred_columns = fkey_d['referred_columns']
@ -337,7 +722,8 @@ class Inspector(object):
if referred_schema is not None:
sa_schema.Table(referred_table, table.metadata,
autoload=True, schema=referred_schema,
autoload_with=self.conn,
autoload_with=self.bind,
_extend_on=_extend_on,
**reflection_options
)
for column in referred_columns:
@ -345,26 +731,113 @@ class Inspector(object):
[referred_schema, referred_table, column]))
else:
sa_schema.Table(referred_table, table.metadata, autoload=True,
autoload_with=self.conn,
autoload_with=self.bind,
schema=sa_schema.BLANK_SCHEMA,
_extend_on=_extend_on,
**reflection_options
)
for column in referred_columns:
refspec.append(".".join([referred_table, column]))
if 'options' in fkey_d:
options = fkey_d['options']
else:
options = {}
table.append_constraint(
sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
conname, link_to_name=True))
# Indexes
conname, link_to_name=True,
**options))
def _reflect_indexes(
self, table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options):
# Indexes
indexes = self.get_indexes(table_name, schema)
for index_d in indexes:
name = index_d['name']
columns = index_d['column_names']
unique = index_d['unique']
flavor = index_d.get('type', 'unknown type')
flavor = index_d.get('type', 'index')
dialect_options = index_d.get('dialect_options', {})
duplicates = index_d.get('duplicates_constraint')
if include_columns and \
not set(columns).issubset(include_columns):
not set(columns).issubset(include_columns):
util.warn(
"Omitting %s KEY for (%s), key covers omitted columns." %
"Omitting %s key for (%s), key covers omitted columns." %
(flavor, ', '.join(columns)))
continue
sa_schema.Index(name, *[table.columns[c] for c in columns],
**dict(unique=unique))
if duplicates:
continue
# look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
idx_cols = []
for c in columns:
try:
idx_col = cols_by_orig_name[c] \
if c in cols_by_orig_name else table.c[c]
except KeyError:
util.warn(
"%s key '%s' was not located in "
"columns for table '%s'" % (
flavor, c, table_name
))
else:
idx_cols.append(idx_col)
sa_schema.Index(
name, *idx_cols,
**dict(list(dialect_options.items()) + [('unique', unique)])
)
def _reflect_unique_constraints(
self, table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options):
# Unique Constraints
try:
constraints = self.get_unique_constraints(table_name, schema)
except NotImplementedError:
# optional dialect feature
return
for const_d in constraints:
conname = const_d['name']
columns = const_d['column_names']
duplicates = const_d.get('duplicates_index')
if include_columns and \
not set(columns).issubset(include_columns):
util.warn(
"Omitting unique constraint key for (%s), "
"key covers omitted columns." %
', '.join(columns))
continue
if duplicates:
continue
# look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
constrained_cols = []
for c in columns:
try:
constrained_col = cols_by_orig_name[c] \
if c in cols_by_orig_name else table.c[c]
except KeyError:
util.warn(
"unique constraint key '%s' was not located in "
"columns for table '%s'" % (c, table_name))
else:
constrained_cols.append(constrained_col)
table.append_constraint(
sa_schema.UniqueConstraint(*constrained_cols, name=conname))
def _reflect_check_constraints(
self, table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options):
try:
constraints = self.get_check_constraints(table_name, schema)
except NotImplementedError:
# optional dialect feature
return
for const_d in constraints:
table.append_constraint(
sa_schema.CheckConstraint(**const_d))

View File

@ -1,3 +1,10 @@
# engine/strategies.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Strategies for creating new instances of Engine types.
These are semi-private implementation classes which provide the
@ -11,18 +18,19 @@ New strategies can be added via new ``EngineStrategy`` classes.
from operator import attrgetter
from sqlalchemy.engine import base, threadlocal, url
from sqlalchemy import util, exc
from sqlalchemy import util, event
from sqlalchemy import pool as poollib
from sqlalchemy.sql import schema
strategies = {}
class EngineStrategy(object):
"""An adaptor that processes input arguements and produces an Engine.
"""An adaptor that processes input arguments and produces an Engine.
Provides a ``create`` method that receives input arguments and
produces an instance of base.Engine or a subclass.
"""
def __init__(self):
@ -35,58 +43,75 @@ class EngineStrategy(object):
class DefaultEngineStrategy(EngineStrategy):
"""Base class for built-in stratgies."""
"""Base class for built-in strategies."""
pool_threadlocal = False
def create(self, name_or_url, **kwargs):
# create url.URL object
u = url.make_url(name_or_url)
dialect_cls = u.get_dialect()
plugins = u._instantiate_plugins(kwargs)
u.query.pop('plugin', None)
entrypoint = u._get_entrypoint()
dialect_cls = entrypoint.get_dialect_cls(u)
if kwargs.pop('_coerce_config', False):
def pop_kwarg(key, default=None):
value = kwargs.pop(key, default)
if key in dialect_cls.engine_config_types:
value = dialect_cls.engine_config_types[key](value)
return value
else:
pop_kwarg = kwargs.pop
dialect_args = {}
# consume dialect arguments from kwargs
for k in util.get_cls_kwargs(dialect_cls):
if k in kwargs:
dialect_args[k] = kwargs.pop(k)
dialect_args[k] = pop_kwarg(k)
dbapi = kwargs.pop('module', None)
if dbapi is None:
dbapi_args = {}
for k in util.get_func_kwargs(dialect_cls.dbapi):
if k in kwargs:
dbapi_args[k] = kwargs.pop(k)
dbapi_args[k] = pop_kwarg(k)
dbapi = dialect_cls.dbapi(**dbapi_args)
dialect_args['dbapi'] = dbapi
for plugin in plugins:
plugin.handle_dialect_kwargs(dialect_cls, dialect_args)
# create dialect
dialect = dialect_cls(**dialect_args)
# assemble connection arguments
(cargs, cparams) = dialect.create_connect_args(u)
cparams.update(kwargs.pop('connect_args', {}))
cparams.update(pop_kwarg('connect_args', {}))
cargs = list(cargs) # allow mutability
# look for existing pool or create
pool = kwargs.pop('pool', None)
pool = pop_kwarg('pool', None)
if pool is None:
def connect():
try:
return dialect.connect(*cargs, **cparams)
except Exception, e:
# Py3K
#raise exc.DBAPIError.instance(None, None, e) from e
# Py2K
import sys
raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2]
# end Py2K
creator = kwargs.pop('creator', connect)
def connect(connection_record=None):
if dialect._has_events:
for fn in dialect.dispatch.do_connect:
connection = fn(
dialect, connection_record, cargs, cparams)
if connection is not None:
return connection
return dialect.connect(*cargs, **cparams)
poolclass = (kwargs.pop('poolclass', None) or
getattr(dialect_cls, 'poolclass', poollib.QueuePool))
pool_args = {}
creator = pop_kwarg('creator', connect)
poolclass = pop_kwarg('poolclass', None)
if poolclass is None:
poolclass = dialect_cls.get_pool_class(u)
pool_args = {
'dialect': dialect
}
# consume pool arguments from kwargs, translating a few of
# the arguments
@ -94,12 +119,17 @@ class DefaultEngineStrategy(EngineStrategy):
'echo': 'echo_pool',
'timeout': 'pool_timeout',
'recycle': 'pool_recycle',
'use_threadlocal':'pool_threadlocal'}
'events': 'pool_events',
'use_threadlocal': 'pool_threadlocal',
'reset_on_return': 'pool_reset_on_return'}
for k in util.get_cls_kwargs(poolclass):
tk = translate.get(k, k)
if tk in kwargs:
pool_args[k] = kwargs.pop(tk)
pool_args.setdefault('use_threadlocal', self.pool_threadlocal)
pool_args[k] = pop_kwarg(tk)
for plugin in plugins:
plugin.handle_pool_kwargs(poolclass, pool_args)
pool = poolclass(creator, **pool_args)
else:
if isinstance(pool, poollib._DBProxy):
@ -107,15 +137,17 @@ class DefaultEngineStrategy(EngineStrategy):
else:
pool = pool
pool._dialect = dialect
# create engine.
engineclass = self.engine_cls
engine_args = {}
for k in util.get_cls_kwargs(engineclass):
if k in kwargs:
engine_args[k] = kwargs.pop(k)
engine_args[k] = pop_kwarg(k)
_initialize = kwargs.pop('_initialize', True)
# all kwargs should be consumed
if kwargs:
raise TypeError(
@ -126,24 +158,35 @@ class DefaultEngineStrategy(EngineStrategy):
dialect.__class__.__name__,
pool.__class__.__name__,
engineclass.__name__))
engine = engineclass(pool, dialect, u, **engine_args)
if _initialize:
do_on_connect = dialect.on_connect()
if do_on_connect:
def on_connect(conn, rec):
conn = getattr(conn, '_sqla_unwrap', conn)
def on_connect(dbapi_connection, connection_record):
conn = getattr(
dbapi_connection, '_sqla_unwrap', dbapi_connection)
if conn is None:
return
do_on_connect(conn)
pool.add_listener({'first_connect': on_connect, 'connect':on_connect})
def first_connect(conn, rec):
c = base.Connection(engine, connection=conn)
event.listen(pool, 'first_connect', on_connect)
event.listen(pool, 'connect', on_connect)
def first_connect(dbapi_connection, connection_record):
c = base.Connection(engine, connection=dbapi_connection,
_has_events=False)
c._execution_options = util.immutabledict()
dialect.initialize(c)
pool.add_listener({'first_connect':first_connect})
event.listen(pool, 'first_connect', first_connect, once=True)
dialect_cls.engine_created(engine)
if entrypoint is not dialect_cls:
entrypoint.engine_created(engine)
for plugin in plugins:
plugin.engine_created(engine)
return engine
@ -153,15 +196,14 @@ class PlainEngineStrategy(DefaultEngineStrategy):
name = 'plain'
engine_cls = base.Engine
PlainEngineStrategy()
class ThreadLocalEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring an Engine with thredlocal behavior."""
"""Strategy for configuring an Engine with threadlocal behavior."""
name = 'threadlocal'
pool_threadlocal = True
engine_cls = threadlocal.TLEngine
ThreadLocalEngineStrategy()
@ -172,11 +214,11 @@ class MockEngineStrategy(EngineStrategy):
Produces a single mock Connectable object which dispatches
statement execution to a passed-in function.
"""
name = 'mock'
def create(self, name_or_url, executor, **kwargs):
# create url.URL object
u = url.make_url(name_or_url)
@ -203,9 +245,14 @@ class MockEngineStrategy(EngineStrategy):
dialect = property(attrgetter('_dialect'))
name = property(lambda s: s._dialect.name)
schema_for_object = schema._schema_getter(None)
def contextual_connect(self, **kwargs):
return self
def execution_options(self, **kw):
return self
def compiler(self, statement, parameters, **kwargs):
return self._dialect.compiler(
statement, parameters, engine=self, **kwargs)
@ -213,13 +260,22 @@ class MockEngineStrategy(EngineStrategy):
def create(self, entity, **kwargs):
kwargs['checkfirst'] = False
from sqlalchemy.engine import ddl
ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity)
ddl.SchemaGenerator(
self.dialect, self, **kwargs).traverse_single(entity)
def drop(self, entity, **kwargs):
kwargs['checkfirst'] = False
from sqlalchemy.engine import ddl
ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity)
ddl.SchemaDropper(
self.dialect, self, **kwargs).traverse_single(entity)
def _run_visitor(self, visitorcallable, element,
connection=None,
**kwargs):
kwargs['checkfirst'] = False
visitorcallable(self.dialect, self,
**kwargs).traverse_single(element)
def execute(self, object, *multiparams, **params):
raise NotImplementedError()

View File

@ -1,23 +1,33 @@
# engine/threadlocal.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Provides a thread-local transactional wrapper around the root Engine class.
The ``threadlocal`` module is invoked when using the ``strategy="threadlocal"`` flag
with :func:`~sqlalchemy.engine.create_engine`. This module is semi-private and is
invoked automatically when the threadlocal engine strategy is used.
The ``threadlocal`` module is invoked when using the
``strategy="threadlocal"`` flag with :func:`~sqlalchemy.engine.create_engine`.
This module is semi-private and is invoked automatically when the threadlocal
engine strategy is used.
"""
from sqlalchemy import util
from sqlalchemy.engine import base
from .. import util
from . import base
import weakref
class TLConnection(base.Connection):
def __init__(self, *arg, **kw):
super(TLConnection, self).__init__(*arg, **kw)
self.__opencount = 0
def _increment_connect(self):
self.__opencount += 1
return self
def close(self):
if self.__opencount == 1:
base.Connection.close(self)
@ -27,70 +37,95 @@ class TLConnection(base.Connection):
self.__opencount = 0
base.Connection.close(self)
class TLEngine(base.Engine):
"""An Engine that includes support for thread-local managed transactions."""
class TLEngine(base.Engine):
"""An Engine that includes support for thread-local managed
transactions.
"""
_tl_connection_cls = TLConnection
def __init__(self, *args, **kwargs):
super(TLEngine, self).__init__(*args, **kwargs)
self._connections = util.threading.local()
proxy = kwargs.get('proxy')
if proxy:
self.TLConnection = base._proxy_connection_cls(TLConnection, proxy)
else:
self.TLConnection = TLConnection
def contextual_connect(self, **kw):
if not hasattr(self._connections, 'conn'):
connection = None
else:
connection = self._connections.conn()
if connection is None or connection.closed:
# guards against pool-level reapers, if desired.
# or not connection.connection.is_valid:
connection = self.TLConnection(self, self.pool.connect(), **kw)
self._connections.conn = conn = weakref.ref(connection)
connection = self._tl_connection_cls(
self,
self._wrap_pool_connect(
self.pool.connect, connection),
**kw)
self._connections.conn = weakref.ref(connection)
return connection._increment_connect()
def begin_twophase(self, xid=None):
if not hasattr(self._connections, 'trans'):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid))
self._connections.trans.append(
self.contextual_connect().begin_twophase(xid=xid))
return self
def begin_nested(self):
if not hasattr(self._connections, 'trans'):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin_nested())
self._connections.trans.append(
self.contextual_connect().begin_nested())
return self
def begin(self):
if not hasattr(self._connections, 'trans'):
self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin())
return self
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
if type is None:
self.commit()
else:
self.rollback()
def prepare(self):
if not hasattr(self._connections, 'trans') or \
not self._connections.trans:
return
self._connections.trans[-1].prepare()
def commit(self):
if not hasattr(self._connections, 'trans') or \
not self._connections.trans:
return
trans = self._connections.trans.pop(-1)
trans.commit()
def rollback(self):
if not hasattr(self._connections, 'trans') or \
not self._connections.trans:
return
trans = self._connections.trans.pop(-1)
trans.rollback()
def dispose(self):
self._connections = util.threading.local()
super(TLEngine, self).dispose()
@property
def closed(self):
return not hasattr(self._connections, 'conn') or \
self._connections.conn() is None or \
self._connections.conn().closed
self._connections.conn() is None or \
self._connections.conn().closed
def close(self):
if not self.closed:
self.contextual_connect().close()
@ -98,6 +133,6 @@ class TLEngine(base.Engine):
connection._force_close()
del self._connections.conn
self._connections.trans = []
def __repr__(self):
return 'TLEngine(%s)' % str(self.url)
return 'TLEngine(%r)' % self.url

View File

@ -1,13 +1,23 @@
# engine/url.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates
information about a database connection specification.
The URL object is created automatically when :func:`~sqlalchemy.engine.create_engine` is called
with a string argument; alternatively, the URL is a public-facing construct which can
The URL object is created automatically when
:func:`~sqlalchemy.engine.create_engine` is called with a string
argument; alternatively, the URL is a public-facing construct which can
be used directly and is also accepted directly by ``create_engine()``.
"""
import re, cgi, sys, urllib
from sqlalchemy import exc
import re
from .. import exc, util
from . import Dialect
from ..dialects import registry, plugins
class URL(object):
@ -15,8 +25,8 @@ class URL(object):
Represent the components of a URL used to connect to a database.
This object is suitable to be passed directly to a
``create_engine()`` call. The fields of the URL are parsed from a
string by the ``module-level make_url()`` function. the string
:func:`~sqlalchemy.create_engine` call. The fields of the URL are parsed
from a string by the :func:`.make_url` function. the string
format of the URL is an RFC-1738-style string.
All initialization parameters are available as public attributes.
@ -53,25 +63,35 @@ class URL(object):
self.database = database
self.query = query or {}
def __str__(self):
def __to_string__(self, hide_password=True):
s = self.drivername + "://"
if self.username is not None:
s += self.username
s += _rfc_1738_quote(self.username)
if self.password is not None:
s += ':' + urllib.quote_plus(self.password)
s += ':' + ('***' if hide_password
else _rfc_1738_quote(self.password))
s += "@"
if self.host is not None:
s += self.host
if ':' in self.host:
s += "[%s]" % self.host
else:
s += self.host
if self.port is not None:
s += ':' + str(self.port)
if self.database is not None:
s += '/' + self.database
if self.query:
keys = self.query.keys()
keys = list(self.query)
keys.sort()
s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
return s
def __str__(self):
return self.__to_string__(hide_password=False)
def __repr__(self):
return self.__to_string__()
def __hash__(self):
return hash(str(self))
@ -85,49 +105,58 @@ class URL(object):
self.database == other.database and \
self.query == other.query
def get_backend_name(self):
if '+' not in self.drivername:
return self.drivername
else:
return self.drivername.split('+')[0]
def get_driver_name(self):
if '+' not in self.drivername:
return self.get_dialect().driver
else:
return self.drivername.split('+')[1]
def _instantiate_plugins(self, kwargs):
plugin_names = util.to_list(self.query.get('plugin', ()))
return [
plugins.load(plugin_name)(self, kwargs)
for plugin_name in plugin_names
]
def _get_entrypoint(self):
"""Return the "entry point" dialect class.
This is normally the dialect itself except in the case when the
returned class implements the get_dialect_cls() method.
"""
if '+' not in self.drivername:
name = self.drivername
else:
name = self.drivername.replace('+', '.')
cls = registry.load(name)
# check for legacy dialects that
# would return a module with 'dialect' as the
# actual class
if hasattr(cls, 'dialect') and \
isinstance(cls.dialect, type) and \
issubclass(cls.dialect, Dialect):
return cls.dialect
else:
return cls
def get_dialect(self):
"""Return the SQLAlchemy database dialect class corresponding
to this URL's driver name.
"""
entrypoint = self._get_entrypoint()
dialect_cls = entrypoint.get_dialect_cls(self)
return dialect_cls
try:
if '+' in self.drivername:
dialect, driver = self.drivername.split('+')
else:
dialect, driver = self.drivername, 'base'
module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
module = getattr(module, dialect)
module = getattr(module, driver)
return module.dialect
except ImportError:
module = self._load_entry_point()
if module is not None:
return module
else:
raise
def _load_entry_point(self):
"""attempt to load this url's dialect from entry points, or return None
if pkg_resources is not installed or there is no matching entry point.
Raise ImportError if the actual load fails.
"""
try:
import pkg_resources
except ImportError:
return None
for res in pkg_resources.iter_entry_points('sqlalchemy.dialects'):
if res.name == self.drivername:
return res.load()
else:
return None
def translate_connect_args(self, names=[], **kw):
"""Translate url attributes into a dictionary of connection arguments.
r"""Translate url attributes into a dictionary of connection arguments.
Returns attributes of this url (`host`, `database`, `username`,
`password`, `port`) as a plain dictionary. The attribute names are
@ -136,8 +165,8 @@ class URL(object):
:param \**kw: Optional, alternate key names for url attributes.
:param names: Deprecated. Same purpose as the keyword-based alternate names,
but correlates the name to the original positionally.
:param names: Deprecated. Same purpose as the keyword-based alternate
names, but correlates the name to the original positionally.
"""
translated = {}
@ -153,6 +182,7 @@ class URL(object):
translated[name] = getattr(self, sname)
return translated
def make_url(name_or_url):
"""Given a string or unicode instance, produce a new URL instance.
@ -160,25 +190,28 @@ def make_url(name_or_url):
existing URL object is passed, just returns the object.
"""
if isinstance(name_or_url, basestring):
if isinstance(name_or_url, util.string_types):
return _parse_rfc1738_args(name_or_url)
else:
return name_or_url
def _parse_rfc1738_args(name):
pattern = re.compile(r'''
(?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
(?::(?P<password>[^/]*))?
(?::(?P<password>.*))?
@)?
(?:
(?P<host>[^/:]*)
(?:
\[(?P<ipv6host>[^/]+)\] |
(?P<ipv4host>[^/:]+)
)?
(?::(?P<port>[^/]*))?
)?
(?:/(?P<database>.*))?
'''
, re.X)
''', re.X)
m = pattern.match(name)
if m is not None:
@ -186,29 +219,43 @@ def _parse_rfc1738_args(name):
if components['database'] is not None:
tokens = components['database'].split('?', 2)
components['database'] = tokens[0]
query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None
# Py2K
if query is not None:
query = (
len(tokens) > 1 and dict(util.parse_qsl(tokens[1]))) or None
if util.py2k and query is not None:
query = dict((k.encode('ascii'), query[k]) for k in query)
# end Py2K
else:
query = None
components['query'] = query
if components['password'] is not None:
components['password'] = urllib.unquote_plus(components['password'])
if components['username'] is not None:
components['username'] = _rfc_1738_unquote(components['username'])
if components['password'] is not None:
components['password'] = _rfc_1738_unquote(components['password'])
ipv4host = components.pop('ipv4host')
ipv6host = components.pop('ipv6host')
components['host'] = ipv4host or ipv6host
name = components.pop('name')
return URL(name, **components)
else:
raise exc.ArgumentError(
"Could not parse rfc1738 URL from string '%s'" % name)
def _rfc_1738_quote(text):
return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text)
def _rfc_1738_unquote(text):
return util.unquote(text)
def _parse_keyvalue_args(name):
m = re.match( r'(\w+)://(.*)', name)
m = re.match(r'(\w+)://(.*)', name)
if m is not None:
(name, args) = m.group(1, 2)
opts = dict( cgi.parse_qsl( args ) )
opts = dict(util.parse_qsl(args))
return URL(name, *opts)
else:
return None

View File

@ -1,13 +1,15 @@
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
# sqlalchemy/exc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Exceptions used with SQLAlchemy.
The base exception class is SQLAlchemyError. Exceptions which are raised as a
result of DBAPI exceptions are all subclasses of
:class:`~sqlalchemy.exc.DBAPIError`.
The base exception class is :exc:`.SQLAlchemyError`. Exceptions which are
raised as a result of DBAPI exceptions are all subclasses of
:exc:`.DBAPIError`.
"""
@ -24,31 +26,100 @@ class ArgumentError(SQLAlchemyError):
"""
class ObjectNotExecutableError(ArgumentError):
"""Raised when an object is passed to .execute() that can't be
executed as SQL.
.. versionadded:: 1.1
"""
def __init__(self, target):
super(ObjectNotExecutableError, self).__init__(
"Not an executable object: %r" % target
)
class NoSuchModuleError(ArgumentError):
"""Raised when a dynamically-loaded module (usually a database dialect)
of a particular name cannot be located."""
class NoForeignKeysError(ArgumentError):
"""Raised when no foreign keys can be located between two selectables
during a join."""
class AmbiguousForeignKeysError(ArgumentError):
"""Raised when more than one foreign key matching can be located
between two selectables during a join."""
class CircularDependencyError(SQLAlchemyError):
"""Raised by topological sorts when a circular dependency is detected"""
"""Raised by topological sorts when a circular dependency is detected.
There are two scenarios where this error occurs:
* In a Session flush operation, if two objects are mutually dependent
on each other, they can not be inserted or deleted via INSERT or
DELETE statements alone; an UPDATE will be needed to post-associate
or pre-deassociate one of the foreign key constrained values.
The ``post_update`` flag described at :ref:`post_update` can resolve
this cycle.
* In a :attr:`.MetaData.sorted_tables` operation, two :class:`.ForeignKey`
or :class:`.ForeignKeyConstraint` objects mutually refer to each
other. Apply the ``use_alter=True`` flag to one or both,
see :ref:`use_alter`.
"""
def __init__(self, message, cycles, edges, msg=None):
if msg is None:
message += " (%s)" % ", ".join(repr(s) for s in cycles)
else:
message = msg
SQLAlchemyError.__init__(self, message)
self.cycles = cycles
self.edges = edges
def __reduce__(self):
return self.__class__, (None, self.cycles,
self.edges, self.args[0])
class CompileError(SQLAlchemyError):
"""Raised when an error occurs during SQL compilation"""
class UnsupportedCompilationError(CompileError):
"""Raised when an operation is not supported by the given compiler.
.. versionadded:: 0.8.3
"""
def __init__(self, compiler, element_type):
super(UnsupportedCompilationError, self).__init__(
"Compiler %r can't render element of type %s" %
(compiler, element_type))
class IdentifierError(SQLAlchemyError):
"""Raised when a schema name is beyond the max character limit"""
# Moved to orm.exc; compatability definition installed by orm import until 0.6
ConcurrentModificationError = None
class DisconnectionError(SQLAlchemyError):
"""A disconnect is detected on a raw DB-API connection.
This error is raised and consumed internally by a connection pool. It can
be raised by a ``PoolListener`` so that the host pool forces a disconnect.
be raised by the :meth:`.PoolEvents.checkout` event so that the host pool
forces a retry; the exception will be caught three times in a row before
the pool gives up and raises :class:`~sqlalchemy.exc.InvalidRequestError`
regarding the connection attempt.
"""
# Moved to orm.exc; compatability definition installed by orm import until 0.6
FlushError = None
class TimeoutError(SQLAlchemyError):
"""Raised when a connection pool times out on getting a connection."""
@ -60,17 +131,52 @@ class InvalidRequestError(SQLAlchemyError):
"""
class NoInspectionAvailable(InvalidRequestError):
"""A subject passed to :func:`sqlalchemy.inspection.inspect` produced
no context for inspection."""
class ResourceClosedError(InvalidRequestError):
"""An operation was requested from a connection, cursor, or other
object that's in a closed state."""
class NoSuchColumnError(KeyError, InvalidRequestError):
"""A nonexistent column is requested from a ``RowProxy``."""
class NoReferenceError(InvalidRequestError):
"""Raised by ``ForeignKey`` to indicate a reference cannot be resolved."""
class NoReferencedTableError(NoReferenceError):
"""Raised by ``ForeignKey`` when the referred ``Table`` cannot be located."""
"""Raised by ``ForeignKey`` when the referred ``Table`` cannot be
located.
"""
def __init__(self, message, tname):
NoReferenceError.__init__(self, message)
self.table_name = tname
def __reduce__(self):
return self.__class__, (self.args[0], self.table_name)
class NoReferencedColumnError(NoReferenceError):
"""Raised by ``ForeignKey`` when the referred ``Column`` cannot be located."""
"""Raised by ``ForeignKey`` when the referred ``Column`` cannot be
located.
"""
def __init__(self, message, tname, cname):
NoReferenceError.__init__(self, message)
self.table_name = tname
self.column_name = cname
def __reduce__(self):
return self.__class__, (self.args[0], self.table_name,
self.column_name)
class NoSuchTableError(InvalidRequestError):
"""Table does not exist or is not visible to a connection."""
@ -80,70 +186,161 @@ class UnboundExecutionError(InvalidRequestError):
"""SQL was attempted without a database connection to execute it on."""
# Moved to orm.exc; compatability definition installed by orm import until 0.6
class DontWrapMixin(object):
"""A mixin class which, when applied to a user-defined Exception class,
will not be wrapped inside of :exc:`.StatementError` if the error is
emitted within the process of executing a statement.
E.g.::
from sqlalchemy.exc import DontWrapMixin
class MyCustomException(Exception, DontWrapMixin):
pass
class MySpecialType(TypeDecorator):
impl = String
def process_bind_param(self, value, dialect):
if value == 'invalid':
raise MyCustomException("invalid!")
"""
# Moved to orm.exc; compatibility definition installed by orm import until 0.6
UnmappedColumnError = None
class DBAPIError(SQLAlchemyError):
class StatementError(SQLAlchemyError):
"""An error occurred during execution of a SQL statement.
:class:`StatementError` wraps the exception raised
during execution, and features :attr:`.statement`
and :attr:`.params` attributes which supply context regarding
the specifics of the statement which had an issue.
The wrapped exception object is available in
the :attr:`.orig` attribute.
"""
statement = None
"""The string SQL statement being invoked when this exception occurred."""
params = None
"""The parameter list being used when this exception occurred."""
orig = None
"""The DBAPI exception object."""
def __init__(self, message, statement, params, orig):
SQLAlchemyError.__init__(self, message)
self.statement = statement
self.params = params
self.orig = orig
self.detail = []
def add_detail(self, msg):
self.detail.append(msg)
def __reduce__(self):
return self.__class__, (self.args[0], self.statement,
self.params, self.orig)
def __str__(self):
from sqlalchemy.sql import util
details = [SQLAlchemyError.__str__(self)]
if self.statement:
details.append("[SQL: %r]" % self.statement)
if self.params:
params_repr = util._repr_params(self.params, 10)
details.append("[parameters: %r]" % params_repr)
return ' '.join([
"(%s)" % det for det in self.detail
] + details)
def __unicode__(self):
return self.__str__()
class DBAPIError(StatementError):
"""Raised when the execution of a database operation fails.
``DBAPIError`` wraps exceptions raised by the DB-API underlying the
Wraps exceptions raised by the DB-API underlying the
database operation. Driver-specific implementations of the standard
DB-API exception types are wrapped by matching sub-types of SQLAlchemy's
``DBAPIError`` when possible. DB-API's ``Error`` type maps to
``DBAPIError`` in SQLAlchemy, otherwise the names are identical. Note
:class:`DBAPIError` when possible. DB-API's ``Error`` type maps to
:class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note
that there is no guarantee that different DB-API implementations will
raise the same exception type for any given error condition.
If the error-raising operation occured in the execution of a SQL
statement, that statement and its parameters will be available on
the exception object in the ``statement`` and ``params`` attributes.
:class:`DBAPIError` features :attr:`~.StatementError.statement`
and :attr:`~.StatementError.params` attributes which supply context
regarding the specifics of the statement which had an issue, for the
typical case when the error was raised within the context of
emitting a SQL statement.
The wrapped exception object is available in the ``orig`` attribute.
Its type and properties are DB-API implementation specific.
The wrapped exception object is available in the
:attr:`~.StatementError.orig` attribute. Its type and properties are
DB-API implementation specific.
"""
@classmethod
def instance(cls, statement, params, orig, connection_invalidated=False):
def instance(cls, statement, params,
orig, dbapi_base_err,
connection_invalidated=False,
dialect=None):
# Don't ever wrap these, just return them directly as if
# DBAPIError didn't exist.
if isinstance(orig, (KeyboardInterrupt, SystemExit)):
if (isinstance(orig, BaseException) and
not isinstance(orig, Exception)) or \
isinstance(orig, DontWrapMixin):
return orig
if orig is not None:
name, glob = orig.__class__.__name__, globals()
if name in glob and issubclass(glob[name], DBAPIError):
cls = glob[name]
# not a DBAPI error, statement is present.
# raise a StatementError
if not isinstance(orig, dbapi_base_err) and statement:
return StatementError(
"(%s.%s) %s" %
(orig.__class__.__module__, orig.__class__.__name__,
orig),
statement, params, orig
)
glob = globals()
for super_ in orig.__class__.__mro__:
name = super_.__name__
if dialect:
name = dialect.dbapi_exception_translation_map.get(
name, name)
if name in glob and issubclass(glob[name], DBAPIError):
cls = glob[name]
break
return cls(statement, params, orig, connection_invalidated)
def __reduce__(self):
return self.__class__, (self.statement, self.params,
self.orig, self.connection_invalidated)
def __init__(self, statement, params, orig, connection_invalidated=False):
try:
text = str(orig)
except (KeyboardInterrupt, SystemExit):
raise
except Exception, e:
except Exception as e:
text = 'Error in str() of DB-API-generated exception: ' + str(e)
SQLAlchemyError.__init__(
self, '(%s) %s' % (orig.__class__.__name__, text))
self.statement = statement
self.params = params
self.orig = orig
StatementError.__init__(
self,
'(%s.%s) %s' % (
orig.__class__.__module__, orig.__class__.__name__, text, ),
statement,
params,
orig
)
self.connection_invalidated = connection_invalidated
def __str__(self):
if isinstance(self.params, (list, tuple)) and len(self.params) > 10 and isinstance(self.params[0], (list, dict, tuple)):
return ' '.join((SQLAlchemyError.__str__(self),
repr(self.statement),
repr(self.params[:2]),
'... and a total of %i bound parameter sets' % len(self.params)))
return ' '.join((SQLAlchemyError.__str__(self),
repr(self.statement), repr(self.params)))
# As of 0.4, SQLError is now DBAPIError.
# SQLError alias will be removed in 0.6.
SQLError = DBAPIError
class InterfaceError(DBAPIError):
"""Wraps a DB-API InterfaceError."""

View File

@ -1 +1,11 @@
# ext/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .. import util as _sa_util
_sa_util.dependencies.resolve_all("sqlalchemy.ext")

View File

@ -1,3 +1,10 @@
# ext/associationproxy.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Contain the ``AssociationProxy`` class.
The ``AssociationProxy`` is a Python property object which provides
@ -9,43 +16,37 @@ See the example ``examples/association/proxied_association.py``.
import itertools
import operator
import weakref
from sqlalchemy import exceptions
from sqlalchemy import orm
from sqlalchemy import util
from sqlalchemy.orm import collections
from sqlalchemy.sql import not_
from .. import exc, orm, util
from ..orm import collections, interfaces
from ..sql import not_, or_
def association_proxy(target_collection, attr, **kw):
"""Return a Python property implementing a view of *attr* over a collection.
r"""Return a Python property implementing a view of a target
attribute which references an attribute on members of the
target.
Implements a read/write view over an instance's *target_collection*,
extracting *attr* from each member of the collection. The property acts
somewhat like this list comprehension::
The returned value is an instance of :class:`.AssociationProxy`.
[getattr(member, *attr*)
for member in getattr(instance, *target_collection*)]
Implements a Python property representing a relationship as a collection
of simpler values, or a scalar value. The proxied property will mimic
the collection type of the target (list, dict or set), or, in the case of
a one to one relationship, a simple scalar value.
Unlike the list comprehension, the collection returned by the property is
always in sync with *target_collection*, and mutations made to either
collection will be reflected in both.
:param target_collection: Name of the attribute we'll proxy to.
This attribute is typically mapped by
:func:`~sqlalchemy.orm.relationship` to link to a target collection, but
can also be a many-to-one or non-scalar relationship.
Implements a Python property representing a relationship as a collection of
simpler values. The proxied property will mimic the collection type of
the target (list, dict or set), or, in the case of a one to one relationship,
a simple scalar value.
:param target_collection: Name of the relationship attribute we'll proxy to,
usually created with :func:`~sqlalchemy.orm.relationship`.
:param attr: Attribute on the associated instances we'll proxy for.
:param attr: Attribute on the associated instance or instances we'll
proxy for.
For example, given a target collection of [obj1, obj2], a list created
by this proxy property would look like [getattr(obj1, *attr*),
getattr(obj2, *attr*)]
If the relationship is one-to-one or otherwise uselist=False, then simply:
getattr(obj, *attr*)
If the relationship is one-to-one or otherwise uselist=False, then
simply: getattr(obj, *attr*)
:param creator: optional.
@ -69,59 +70,78 @@ def association_proxy(target_collection, attr, **kw):
situation.
:param \*\*kw: Passes along any other keyword arguments to
:class:`AssociationProxy`.
:class:`.AssociationProxy`.
"""
return AssociationProxy(target_collection, attr, **kw)
class AssociationProxy(object):
ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY')
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.AssociationProxy`.
Is assigned to the :attr:`.InspectionAttr.extension_type`
attibute.
"""
class AssociationProxy(interfaces.InspectionAttrInfo):
"""A descriptor that presents a read/write view of an object attribute."""
is_attribute = False
extension_type = ASSOCIATION_PROXY
def __init__(self, target_collection, attr, creator=None,
getset_factory=None, proxy_factory=None, proxy_bulk_set=None):
"""Arguments are:
getset_factory=None, proxy_factory=None,
proxy_bulk_set=None, info=None):
"""Construct a new :class:`.AssociationProxy`.
target_collection
Name of the collection we'll proxy to, usually created with
'relationship()' in a mapper setup.
The :func:`.association_proxy` function is provided as the usual
entrypoint here, though :class:`.AssociationProxy` can be instantiated
and/or subclassed directly.
attr
Attribute on the collected instances we'll proxy for. For example,
given a target collection of [obj1, obj2], a list created by this
proxy property would look like [getattr(obj1, attr), getattr(obj2,
attr)]
:param target_collection: Name of the collection we'll proxy to,
usually created with :func:`.relationship`.
creator
Optional. When new items are added to this proxied collection, new
instances of the class collected by the target collection will be
created. For list and set collections, the target class constructor
will be called with the 'value' for the new instance. For dict
types, two arguments are passed: key and value.
:param attr: Attribute on the collected instances we'll proxy
for. For example, given a target collection of [obj1, obj2], a
list created by this proxy property would look like
[getattr(obj1, attr), getattr(obj2, attr)]
:param creator: Optional. When new items are added to this proxied
collection, new instances of the class collected by the target
collection will be created. For list and set collections, the
target class constructor will be called with the 'value' for the
new instance. For dict types, two arguments are passed:
key and value.
If you want to construct instances differently, supply a 'creator'
function that takes arguments as above and returns instances.
getset_factory
Optional. Proxied attribute access is automatically handled by
routines that get and set values based on the `attr` argument for
this proxy.
:param getset_factory: Optional. Proxied attribute access is
automatically handled by routines that get and set values based on
the `attr` argument for this proxy.
If you would like to customize this behavior, you may supply a
`getset_factory` callable that produces a tuple of `getter` and
`setter` functions. The factory is called with two arguments, the
abstract type of the underlying collection and this proxy instance.
proxy_factory
Optional. The type of collection to emulate is determined by
sniffing the target collection. If your collection type can't be
determined by duck typing or you'd like to use a different
collection implementation, you may supply a factory function to
produce those collections. Only applicable to non-scalar relationships.
:param proxy_factory: Optional. The type of collection to emulate is
determined by sniffing the target collection. If your collection
type can't be determined by duck typing or you'd like to use a
different collection implementation, you may supply a factory
function to produce those collections. Only applicable to
non-scalar relationships.
proxy_bulk_set
Optional, use with proxy_factory. See the _set() method for
details.
:param proxy_bulk_set: Optional, use with proxy_factory. See
the _set() method for details.
:param info: optional, will be assigned to
:attr:`.AssociationProxy.info` if present.
.. versionadded:: 1.0.9
"""
self.target_collection = target_collection
@ -131,36 +151,107 @@ class AssociationProxy(object):
self.proxy_factory = proxy_factory
self.proxy_bulk_set = proxy_bulk_set
self.scalar = None
self.owning_class = None
self.key = '_%s_%s_%s' % (
type(self).__name__, target_collection, id(self))
self.collection_class = None
if info:
self.info = info
@property
def remote_attr(self):
"""The 'remote' :class:`.MapperProperty` referenced by this
:class:`.AssociationProxy`.
.. versionadded:: 0.7.3
See also:
:attr:`.AssociationProxy.attr`
:attr:`.AssociationProxy.local_attr`
"""
return getattr(self.target_class, self.value_attr)
@property
def local_attr(self):
"""The 'local' :class:`.MapperProperty` referenced by this
:class:`.AssociationProxy`.
.. versionadded:: 0.7.3
See also:
:attr:`.AssociationProxy.attr`
:attr:`.AssociationProxy.remote_attr`
"""
return getattr(self.owning_class, self.target_collection)
@property
def attr(self):
"""Return a tuple of ``(local_attr, remote_attr)``.
This attribute is convenient when specifying a join
using :meth:`.Query.join` across two relationships::
sess.query(Parent).join(*Parent.proxied.attr)
.. versionadded:: 0.7.3
See also:
:attr:`.AssociationProxy.local_attr`
:attr:`.AssociationProxy.remote_attr`
"""
return (self.local_attr, self.remote_attr)
def _get_property(self):
return (orm.class_mapper(self.owning_class).
get_property(self.target_collection))
@property
@util.memoized_property
def target_class(self):
"""The class the proxy is attached to."""
"""The intermediary class handled by this :class:`.AssociationProxy`.
Intercepted append/set/assignment events will result
in the generation of new instances of this class.
"""
return self._get_property().mapper.class_
def _target_is_scalar(self):
return not self._get_property().uselist
@util.memoized_property
def scalar(self):
"""Return ``True`` if this :class:`.AssociationProxy` proxies a scalar
relationship on the local side."""
scalar = not self._get_property().uselist
if scalar:
self._initialize_scalar_accessors()
return scalar
@util.memoized_property
def _value_is_scalar(self):
return not self._get_property().\
mapper.get_property(self.value_attr).uselist
@util.memoized_property
def _target_is_object(self):
return getattr(self.target_class, self.value_attr).impl.uses_objects
def __get__(self, obj, class_):
if self.owning_class is None:
self.owning_class = class_ and class_ or type(obj)
if obj is None:
return self
elif self.scalar is None:
self.scalar = self._target_is_scalar()
if self.scalar:
self._initialize_scalar_accessors()
if self.scalar:
return self._scalar_get(getattr(obj, self.target_collection))
target = getattr(obj, self.target_collection)
return self._scalar_get(target)
else:
try:
# If the owning instance is reborn (orm session resurrect,
@ -173,14 +264,10 @@ class AssociationProxy(object):
proxy = self._new(_lazy_collection(obj, self.target_collection))
setattr(obj, self.key, (id(obj), proxy))
return proxy
def __set__(self, obj, values):
if self.owning_class is None:
self.owning_class = type(obj)
if self.scalar is None:
self.scalar = self._target_is_scalar()
if self.scalar:
self._initialize_scalar_accessors()
if self.scalar:
creator = self.creator and self.creator or self.target_class
@ -209,7 +296,8 @@ class AssociationProxy(object):
def _default_getset(self, collection_class):
attr = self.value_attr
getter = operator.attrgetter(attr)
_getter = operator.attrgetter(attr)
getter = lambda target: _getter(target) if target is not None else None
if collection_class is dict:
setter = lambda o, k, v: setattr(o, attr, v)
else:
@ -221,21 +309,25 @@ class AssociationProxy(object):
self.collection_class = util.duck_type_collection(lazy_collection())
if self.proxy_factory:
return self.proxy_factory(lazy_collection, creator, self.value_attr, self)
return self.proxy_factory(
lazy_collection, creator, self.value_attr, self)
if self.getset_factory:
getter, setter = self.getset_factory(self.collection_class, self)
else:
getter, setter = self._default_getset(self.collection_class)
if self.collection_class is list:
return _AssociationList(lazy_collection, creator, getter, setter, self)
return _AssociationList(
lazy_collection, creator, getter, setter, self)
elif self.collection_class is dict:
return _AssociationDict(lazy_collection, creator, getter, setter, self)
return _AssociationDict(
lazy_collection, creator, getter, setter, self)
elif self.collection_class is set:
return _AssociationSet(lazy_collection, creator, getter, setter, self)
return _AssociationSet(
lazy_collection, creator, getter, setter, self)
else:
raise exceptions.ArgumentError(
raise exc.ArgumentError(
'could not guess which interface to use for '
'collection_class "%s" backing "%s"; specify a '
'proxy_factory and proxy_bulk_set manually' %
@ -248,7 +340,7 @@ class AssociationProxy(object):
getter, setter = self.getset_factory(self.collection_class, self)
else:
getter, setter = self._default_getset(self.collection_class)
proxy.creator = creator
proxy.getter = getter
proxy.setter = setter
@ -263,28 +355,102 @@ class AssociationProxy(object):
elif self.collection_class is set:
proxy.update(values)
else:
raise exceptions.ArgumentError(
'no proxy_bulk_set supplied for custom '
'collection_class implementation')
raise exc.ArgumentError(
'no proxy_bulk_set supplied for custom '
'collection_class implementation')
@property
def _comparator(self):
return self._get_property().comparator
def any(self, criterion=None, **kwargs):
return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs))
"""Produce a proxied 'any' expression using EXISTS.
This expression will be a composed product
using the :meth:`.RelationshipProperty.Comparator.any`
and/or :meth:`.RelationshipProperty.Comparator.has`
operators of the underlying proxied attributes.
"""
if self._target_is_object:
if self._value_is_scalar:
value_expr = getattr(
self.target_class, self.value_attr).has(
criterion, **kwargs)
else:
value_expr = getattr(
self.target_class, self.value_attr).any(
criterion, **kwargs)
else:
value_expr = criterion
# check _value_is_scalar here, otherwise
# we're scalar->scalar - call .any() so that
# the "can't call any() on a scalar" msg is raised.
if self.scalar and not self._value_is_scalar:
return self._comparator.has(
value_expr
)
else:
return self._comparator.any(
value_expr
)
def has(self, criterion=None, **kwargs):
return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs))
"""Produce a proxied 'has' expression using EXISTS.
This expression will be a composed product
using the :meth:`.RelationshipProperty.Comparator.any`
and/or :meth:`.RelationshipProperty.Comparator.has`
operators of the underlying proxied attributes.
"""
if self._target_is_object:
return self._comparator.has(
getattr(self.target_class, self.value_attr).
has(criterion, **kwargs)
)
else:
if criterion is not None or kwargs:
raise exc.ArgumentError(
"Non-empty has() not allowed for "
"column-targeted association proxy; use ==")
return self._comparator.has()
def contains(self, obj):
return self._comparator.any(**{self.value_attr: obj})
"""Produce a proxied 'contains' expression using EXISTS.
This expression will be a composed product
using the :meth:`.RelationshipProperty.Comparator.any`
, :meth:`.RelationshipProperty.Comparator.has`,
and/or :meth:`.RelationshipProperty.Comparator.contains`
operators of the underlying proxied attributes.
"""
if self.scalar and not self._value_is_scalar:
return self._comparator.has(
getattr(self.target_class, self.value_attr).contains(obj)
)
else:
return self._comparator.any(**{self.value_attr: obj})
def __eq__(self, obj):
return self._comparator.has(**{self.value_attr: obj})
# note the has() here will fail for collections; eq_()
# is only allowed with a scalar.
if obj is None:
return or_(
self._comparator.has(**{self.value_attr: obj}),
self._comparator == None
)
else:
return self._comparator.has(**{self.value_attr: obj})
def __ne__(self, obj):
return not_(self.__eq__(obj))
# note the has() here will fail for collections; eq_()
# is only allowed with a scalar.
return self._comparator.has(
getattr(self.target_class, self.value_attr) != obj)
class _lazy_collection(object):
@ -295,22 +461,23 @@ class _lazy_collection(object):
def __call__(self):
obj = self.ref()
if obj is None:
raise exceptions.InvalidRequestError(
"stale association proxy, parent object has gone out of "
"scope")
raise exc.InvalidRequestError(
"stale association proxy, parent object has gone out of "
"scope")
return getattr(obj, self.target)
def __getstate__(self):
return {'obj':self.ref(), 'target':self.target}
return {'obj': self.ref(), 'target': self.target}
def __setstate__(self, state):
self.ref = weakref.ref(state['obj'])
self.target = state['target']
class _AssociationCollection(object):
def __init__(self, lazy_collection, creator, getter, setter, parent):
"""Constructs an _AssociationCollection.
"""Constructs an _AssociationCollection.
This will always be a subclass of either _AssociationList,
_AssociationSet, or _AssociationDict.
@ -344,17 +511,20 @@ class _AssociationCollection(object):
def __len__(self):
return len(self.col)
def __nonzero__(self):
def __bool__(self):
return bool(self.col)
__nonzero__ = __bool__
def __getstate__(self):
return {'parent':self.parent, 'lazy_collection':self.lazy_collection}
return {'parent': self.parent, 'lazy_collection': self.lazy_collection}
def __setstate__(self, state):
self.parent = state['parent']
self.lazy_collection = state['lazy_collection']
self.parent._inflate(self)
class _AssociationList(_AssociationCollection):
"""Generic, converting, list-to-list proxy."""
@ -368,7 +538,10 @@ class _AssociationList(_AssociationCollection):
return self.setter(object, value)
def __getitem__(self, index):
return self._get(self.col[index])
if not isinstance(index, slice):
return self._get(self.col[index])
else:
return [self._get(member) for member in self.col[index]]
def __setitem__(self, index, value):
if not isinstance(index, slice):
@ -382,11 +555,12 @@ class _AssociationList(_AssociationCollection):
stop = index.stop
step = index.step or 1
rng = range(index.start or 0, stop, step)
start = index.start or 0
rng = list(range(index.start or 0, stop, step))
if step == 1:
for i in rng:
del self[index.start]
i = index.start
del self[start]
i = start
for item in value:
self.insert(i, item)
i += 1
@ -429,7 +603,7 @@ class _AssociationList(_AssociationCollection):
for member in self.col:
yield self._get(member)
raise StopIteration
return
def append(self, value):
item = self._create(value)
@ -437,7 +611,7 @@ class _AssociationList(_AssociationCollection):
def count(self, value):
return sum([1 for _ in
itertools.ifilter(lambda v: v == value, iter(self))])
util.itertools_filter(lambda v: v == value, iter(self))])
def extend(self, values):
for v in values:
@ -536,14 +710,16 @@ class _AssociationList(_AssociationCollection):
def __hash__(self):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items():
if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(list, func_name)):
for func_name, func in list(locals().items()):
if (util.callable(func) and func.__name__ == func_name and
not func.__doc__ and hasattr(list, func_name)):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
_NotProvided = util.symbol('_NotProvided')
class _AssociationDict(_AssociationCollection):
"""Generic, converting, dict-to-dict proxy."""
@ -577,7 +753,7 @@ class _AssociationDict(_AssociationCollection):
return key in self.col
def __iter__(self):
return self.col.iterkeys()
return iter(self.col.keys())
def clear(self):
self.col.clear()
@ -622,24 +798,27 @@ class _AssociationDict(_AssociationCollection):
def keys(self):
return self.col.keys()
def iterkeys(self):
return self.col.iterkeys()
if util.py2k:
def iteritems(self):
return ((key, self._get(self.col[key])) for key in self.col)
def values(self):
return [ self._get(member) for member in self.col.values() ]
def itervalues(self):
return (self._get(self.col[key]) for key in self.col)
def itervalues(self):
for key in self.col:
yield self._get(self.col[key])
raise StopIteration
def iterkeys(self):
return self.col.iterkeys()
def items(self):
return [(k, self._get(self.col[k])) for k in self]
def values(self):
return [self._get(member) for member in self.col.values()]
def iteritems(self):
for key in self.col:
yield (key, self._get(self.col[key]))
raise StopIteration
def items(self):
return [(k, self._get(self.col[k])) for k in self]
else:
def items(self):
return ((key, self._get(self.col[key])) for key in self.col)
def values(self):
return (self._get(self.col[key]) for key in self.col)
def pop(self, key, default=_NotProvided):
if default is _NotProvided:
@ -658,11 +837,20 @@ class _AssociationDict(_AssociationCollection):
len(a))
elif len(a) == 1:
seq_or_map = a[0]
for item in seq_or_map:
if isinstance(item, tuple):
self[item[0]] = item[1]
else:
# discern dict from sequence - took the advice from
# http://www.voidspace.org.uk/python/articles/duck_typing.shtml
# still not perfect :(
if hasattr(seq_or_map, 'keys'):
for item in seq_or_map:
self[item] = seq_or_map[item]
else:
try:
for k, v in seq_or_map:
self[k] = v
except ValueError:
raise ValueError(
"dictionary update sequence "
"requires 2-element tuples")
for key, value in kw:
self[key] = value
@ -673,9 +861,9 @@ class _AssociationDict(_AssociationCollection):
def __hash__(self):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items():
if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(dict, func_name)):
for func_name, func in list(locals().items()):
if (util.callable(func) and func.__name__ == func_name and
not func.__doc__ and hasattr(dict, func_name)):
func.__doc__ = getattr(dict, func_name).__doc__
del func_name, func
@ -695,12 +883,14 @@ class _AssociationSet(_AssociationCollection):
def __len__(self):
return len(self.col)
def __nonzero__(self):
def __bool__(self):
if self.col:
return True
else:
return False
__nonzero__ = __bool__
def __contains__(self, value):
for member in self.col:
# testlib.pragma exempt:__eq__
@ -717,7 +907,7 @@ class _AssociationSet(_AssociationCollection):
"""
for member in self.col:
yield self._get(member)
raise StopIteration
return
def add(self, value):
if value not in self:
@ -871,8 +1061,8 @@ class _AssociationSet(_AssociationCollection):
def __hash__(self):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items():
if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(set, func_name)):
for func_name, func in list(locals().items()):
if (util.callable(func) and func.__name__ == func_name and
not func.__doc__ and hasattr(set, func_name)):
func.__doc__ = getattr(set, func_name).__doc__
del func_name, func

View File

@ -1,31 +1,39 @@
"""Provides an API for creation of custom ClauseElements and compilers.
# ext/compiler.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
r"""Provides an API for creation of custom ClauseElements and compilers.
Synopsis
========
Usage involves the creation of one or more :class:`~sqlalchemy.sql.expression.ClauseElement`
subclasses and one or more callables defining its compilation::
Usage involves the creation of one or more
:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or
more callables defining its compilation::
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnClause
class MyColumn(ColumnClause):
pass
@compiles(MyColumn)
def compile_mycolumn(element, compiler, **kw):
return "[%s]" % element.name
Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
the base expression element for named column objects. The ``compiles``
decorator registers itself with the ``MyColumn`` class so that it is invoked
when the object is compiled to a string::
from sqlalchemy import select
s = select([MyColumn('x'), MyColumn('y')])
print str(s)
Produces::
SELECT [x], [y]
@ -50,22 +58,25 @@ invoked for the dialect in use::
@compiles(AlterColumn, 'postgresql')
def visit_alter_column(element, compiler, **kw):
return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, element.column.name)
return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name,
element.column.name)
The second ``visit_alter_table`` will be invoked when any ``postgresql`` dialect is used.
The second ``visit_alter_table`` will be invoked when any ``postgresql``
dialect is used.
Compiling sub-elements of a custom expression construct
=======================================================
The ``compiler`` argument is the :class:`~sqlalchemy.engine.base.Compiled`
object in use. This object can be inspected for any information about the
in-progress compilation, including ``compiler.dialect``,
``compiler.statement`` etc. The :class:`~sqlalchemy.sql.compiler.SQLCompiler`
and :class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
The ``compiler`` argument is the
:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object
can be inspected for any information about the in-progress compilation,
including ``compiler.dialect``, ``compiler.statement`` etc. The
:class:`~sqlalchemy.sql.compiler.SQLCompiler` and
:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
method which can be used for compilation of embedded attributes::
from sqlalchemy.sql.expression import Executable, ClauseElement
class InsertFromSelect(Executable, ClauseElement):
def __init__(self, table, select):
self.table = table
@ -80,36 +91,110 @@ method which can be used for compilation of embedded attributes::
insert = InsertFromSelect(t1, select([t1]).where(t1.c.x>5))
print insert
Produces::
"INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)"
"INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z
FROM mytable WHERE mytable.x > :x_1)"
.. note::
The above ``InsertFromSelect`` construct is only an example, this actual
functionality is already available using the
:meth:`.Insert.from_select` method.
.. note::
The above ``InsertFromSelect`` construct probably wants to have "autocommit"
enabled. See :ref:`enabling_compiled_autocommit` for this step.
Cross Compiling between SQL and DDL compilers
---------------------------------------------
SQL and DDL constructs are each compiled using different base compilers - ``SQLCompiler``
and ``DDLCompiler``. A common need is to access the compilation rules of SQL expressions
from within a DDL expression. The ``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as below where we generate a CHECK
constraint that embeds a SQL expression::
SQL and DDL constructs are each compiled using different base compilers -
``SQLCompiler`` and ``DDLCompiler``. A common need is to access the
compilation rules of SQL expressions from within a DDL expression. The
``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as
below where we generate a CHECK constraint that embeds a SQL expression::
@compiles(MyConstraint)
def compile_my_constraint(constraint, ddlcompiler, **kw):
return "CONSTRAINT %s CHECK (%s)" % (
constraint.name,
ddlcompiler.sql_compiler.process(constraint.expression)
ddlcompiler.sql_compiler.process(
constraint.expression, literal_binds=True)
)
Above, we add an additional flag to the process step as called by
:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This
indicates that any SQL expression which refers to a :class:`.BindParameter`
object or other "literal" object such as those which refer to strings or
integers should be rendered **in-place**, rather than being referred to as
a bound parameter; when emitting DDL, bound parameters are typically not
supported.
.. _enabling_compiled_autocommit:
Enabling Autocommit on a Construct
==================================
Recall from the section :ref:`autocommit` that the :class:`.Engine`, when
asked to execute a construct in the absence of a user-defined transaction,
detects if the given construct represents DML or DDL, that is, a data
modification or data definition statement, which requires (or may require,
in the case of DDL) that the transaction generated by the DBAPI be committed
(recall that DBAPI always has a transaction going on regardless of what
SQLAlchemy does). Checking for this is actually accomplished by checking for
the "autocommit" execution option on the construct. When building a
construct like an INSERT derivation, a new DDL type, or perhaps a stored
procedure that alters data, the "autocommit" option needs to be set in order
for the statement to function with "connectionless" execution
(as described in :ref:`dbengine_implicit`).
Currently a quick way to do this is to subclass :class:`.Executable`, then
add the "autocommit" flag to the ``_execution_options`` dictionary (note this
is a "frozen" dictionary which supplies a generative ``union()`` method)::
from sqlalchemy.sql.expression import Executable, ClauseElement
class MyInsertThing(Executable, ClauseElement):
_execution_options = \
Executable._execution_options.union({'autocommit': True})
More succinctly, if the construct is truly similar to an INSERT, UPDATE, or
DELETE, :class:`.UpdateBase` can be used, which already is a subclass
of :class:`.Executable`, :class:`.ClauseElement` and includes the
``autocommit`` flag::
from sqlalchemy.sql.expression import UpdateBase
class MyInsertThing(UpdateBase):
def __init__(self, ...):
...
DDL elements that subclass :class:`.DDLElement` already have the
"autocommit" flag turned on.
Changing the default compilation of existing constructs
=======================================================
The compiler extension applies just as well to the existing constructs. When overriding
the compilation of a built in SQL construct, the @compiles decorator is invoked upon
the appropriate class (be sure to use the class, i.e. ``Insert`` or ``Select``, instead of the creation function such as ``insert()`` or ``select()``).
The compiler extension applies just as well to the existing constructs. When
overriding the compilation of a built in SQL construct, the @compiles
decorator is invoked upon the appropriate class (be sure to use the class,
i.e. ``Insert`` or ``Select``, instead of the creation function such
as ``insert()`` or ``select()``).
Within the new compilation function, to get at the "original" compilation routine,
use the appropriate visit_XXX method - this because compiler.process() will call upon the
overriding routine and cause an endless loop. Such as, to add "prefix" to all insert statements::
Within the new compilation function, to get at the "original" compilation
routine, use the appropriate visit_XXX method - this
because compiler.process() will call upon the overriding routine and cause
an endless loop. Such as, to add "prefix" to all insert statements::
from sqlalchemy.sql.expression import Insert
@ -117,38 +202,77 @@ overriding routine and cause an endless loop. Such as, to add "prefix" to all
def prefix_inserts(insert, compiler, **kw):
return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
The above compiler will prefix all INSERT statements with "some prefix" when compiled.
The above compiler will prefix all INSERT statements with "some prefix" when
compiled.
.. _type_compilation_extension:
Changing Compilation of Types
=============================
``compiler`` works for types, too, such as below where we implement the
MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
@compiles(String, 'mssql')
@compiles(VARCHAR, 'mssql')
def compile_varchar(element, compiler, **kw):
if element.length == 'max':
return "VARCHAR('max')"
else:
return compiler.visit_VARCHAR(element, **kw)
foo = Table('foo', metadata,
Column('data', VARCHAR('max'))
)
Subclassing Guidelines
======================
A big part of using the compiler extension is subclassing SQLAlchemy expression constructs. To make this easier, the expression and schema packages feature a set of "bases" intended for common tasks. A synopsis is as follows:
A big part of using the compiler extension is subclassing SQLAlchemy
expression constructs. To make this easier, the expression and
schema packages feature a set of "bases" intended for common tasks.
A synopsis is as follows:
* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
expression class. Any SQL expression can be derived from this base, and is
probably the best choice for longer constructs such as specialized INSERT
statements.
* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
"column-like" elements. Anything that you'd place in the "columns" clause of
a SELECT statement (as well as order by and group by) can derive from this -
the object will automatically have Python "comparison" behavior.
:class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
``type`` member which is expression's return type. This can be established
at the instance level in the constructor, or at the class level if its
generally constant::
class timestamp(ColumnElement):
type = TIMESTAMP()
* :class:`~sqlalchemy.sql.expression.FunctionElement` - This is a hybrid of a
* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a
``ColumnElement`` and a "from clause" like object, and represents a SQL
function or stored procedure type of call. Since most databases support
statements along the line of "SELECT FROM <some function>"
``FunctionElement`` adds in the ability to be used in the FROM clause of a
``select()`` construct.
``select()`` construct::
from sqlalchemy.sql.expression import FunctionElement
class coalesce(FunctionElement):
name = 'coalesce'
@compiles(coalesce)
def compile(element, compiler, **kw):
return "coalesce(%s)" % compiler.process(element.clauses)
@compiles(coalesce, 'oracle')
def compile(element, compiler, **kw):
if len(element.clauses) > 2:
raise TypeError("coalesce only supports two arguments on Oracle")
return "nvl(%s)" % compiler.process(element.clauses)
* :class:`~sqlalchemy.schema.DDLElement` - The root of all DDL expressions,
like CREATE TABLE, ALTER TABLE, etc. Compilation of ``DDLElement``
subclasses is issued by a ``DDLCompiler`` instead of a ``SQLCompiler``.
@ -156,39 +280,195 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression
``execute_at()`` method, allowing the construct to be invoked during CREATE
TABLE and DROP TABLE sequences.
* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which should be
used with any expression class that represents a "standalone" SQL statement that
can be passed directly to an ``execute()`` method. It is already implicit
within ``DDLElement`` and ``FunctionElement``.
* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which
should be used with any expression class that represents a "standalone"
SQL statement that can be passed directly to an ``execute()`` method. It
is already implicit within ``DDLElement`` and ``FunctionElement``.
Further Examples
================
"UTC timestamp" function
-------------------------
A function that works like "CURRENT_TIMESTAMP" except applies the
appropriate conversions so that the time is in UTC time. Timestamps are best
stored in relational databases as UTC, without time zones. UTC so that your
database doesn't think time has gone backwards in the hour when daylight
savings ends, without timezones because timezones are like character
encodings - they're best applied only at the endpoints of an application
(i.e. convert to UTC upon user input, re-apply desired timezone upon display).
For PostgreSQL and Microsoft SQL Server::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import DateTime
class utcnow(expression.FunctionElement):
type = DateTime()
@compiles(utcnow, 'postgresql')
def pg_utcnow(element, compiler, **kw):
return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
@compiles(utcnow, 'mssql')
def ms_utcnow(element, compiler, **kw):
return "GETUTCDATE()"
Example usage::
from sqlalchemy import (
Table, Column, Integer, String, DateTime, MetaData
)
metadata = MetaData()
event = Table("event", metadata,
Column("id", Integer, primary_key=True),
Column("description", String(50), nullable=False),
Column("timestamp", DateTime, server_default=utcnow())
)
"GREATEST" function
-------------------
The "GREATEST" function is given any number of arguments and returns the one
that is of the highest value - its equivalent to Python's ``max``
function. A SQL standard version versus a CASE based version which only
accommodates two arguments::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import Numeric
class greatest(expression.FunctionElement):
type = Numeric()
name = 'greatest'
@compiles(greatest)
def default_greatest(element, compiler, **kw):
return compiler.visit_function(element)
@compiles(greatest, 'sqlite')
@compiles(greatest, 'mssql')
@compiles(greatest, 'oracle')
def case_greatest(element, compiler, **kw):
arg1, arg2 = list(element.clauses)
return "CASE WHEN %s > %s THEN %s ELSE %s END" % (
compiler.process(arg1),
compiler.process(arg2),
compiler.process(arg1),
compiler.process(arg2),
)
Example usage::
Session.query(Account).\
filter(
greatest(
Account.checking_balance,
Account.savings_balance) > 10000
)
"false" expression
------------------
Render a "false" constant expression, rendering as "0" on platforms that
don't have a "false" constant::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
class sql_false(expression.ColumnElement):
pass
@compiles(sql_false)
def default_false(element, compiler, **kw):
return "false"
@compiles(sql_false, 'mssql')
@compiles(sql_false, 'mysql')
@compiles(sql_false, 'oracle')
def int_false(element, compiler, **kw):
return "0"
Example usage::
from sqlalchemy import select, union_all
exp = union_all(
select([users.c.name, sql_false().label("enrolled")]),
select([customers.c.name, customers.c.enrolled])
)
"""
from .. import exc
from ..sql import visitors
def compiles(class_, *specs):
"""Register a function as a compiler for a
given :class:`.ClauseElement` type."""
def decorate(fn):
existing = getattr(class_, '_compiler_dispatcher', None)
# get an existing @compiles handler
existing = class_.__dict__.get('_compiler_dispatcher', None)
# get the original handler. All ClauseElement classes have one
# of these, but some TypeEngine classes will not.
existing_dispatch = getattr(class_, '_compiler_dispatch', None)
if not existing:
existing = _dispatcher()
if existing_dispatch:
def _wrap_existing_dispatch(element, compiler, **kw):
try:
return existing_dispatch(element, compiler, **kw)
except exc.UnsupportedCompilationError:
raise exc.CompileError(
"%s construct has no default "
"compilation handler." % type(element))
existing.specs['default'] = _wrap_existing_dispatch
# TODO: why is the lambda needed ?
setattr(class_, '_compiler_dispatch', lambda *arg, **kw: existing(*arg, **kw))
setattr(class_, '_compiler_dispatch',
lambda *arg, **kw: existing(*arg, **kw))
setattr(class_, '_compiler_dispatcher', existing)
if specs:
for s in specs:
existing.specs[s] = fn
else:
existing.specs['default'] = fn
return fn
return decorate
def deregister(class_):
"""Remove all custom compilers associated with a given
:class:`.ClauseElement` type."""
if hasattr(class_, '_compiler_dispatcher'):
# regenerate default _compiler_dispatch
visitors._generate_dispatch(class_)
# remove custom directive
del class_._compiler_dispatcher
class _dispatcher(object):
def __init__(self):
self.specs = {}
def __call__(self, element, compiler, **kw):
# TODO: yes, this could also switch off of DBAPI in use.
fn = self.specs.get(compiler.dialect.name, None)
if not fn:
fn = self.specs['default']
try:
fn = self.specs['default']
except KeyError:
raise exc.CompileError(
"%s construct has no default "
"compilation handler." % type(element))
return fn(element, compiler, **kw)

View File

@ -1,5 +1,6 @@
# horizontal_shard.py
# Copyright (C) the SQLAlchemy authors and contributors
# ext/horizontal_shard.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -9,105 +10,54 @@
Defines a rudimental 'horizontal sharding' system which allows a Session to
distribute queries and persistence operations across multiple databases.
For a usage example, see the :ref:`examples_sharding` example included in
the source distrbution.
For a usage example, see the :ref:`examples_sharding` example included in
the source distribution.
"""
import sqlalchemy.exceptions as sa_exc
from sqlalchemy import util
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.query import Query
from .. import util
from ..orm.session import Session
from ..orm.query import Query
__all__ = ['ShardedSession', 'ShardedQuery']
class ShardedSession(Session):
def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped instance, and possibly a
SQL clause, returns a shard ID. This id may be based off of the
attributes present within the object, or on some round-robin
scheme. If the scheme is based on a selection, it should set
whatever state on the instance to mark it in the future as
participating in that shard.
:param id_chooser: A callable, passed a query and a tuple of identity values, which
should return a list of shard ids where the ID might reside. The
databases will be queried in the order of this listing.
:param query_chooser: For a given Query, returns the list of shard_ids where the query
should be issued. Results from all shards returned will be combined
together into a single listing.
:param shards: A dictionary of string shard names to :class:`~sqlalchemy.engine.base.Engine`
objects.
"""
super(ShardedSession, self).__init__(**kwargs)
self.shard_chooser = shard_chooser
self.id_chooser = id_chooser
self.query_chooser = query_chooser
self.__binds = {}
self._mapper_flush_opts = {'connection_callable':self.connection}
self._query_cls = ShardedQuery
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance)
if self.transaction is not None:
return self.transaction.connection(mapper, shard_id=shard_id)
else:
return self.get_bind(mapper,
shard_id=shard_id,
instance=instance).contextual_connect(**kwargs)
def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance, clause=clause)
return self.__binds[shard_id]
def bind_shard(self, shard_id, bind):
self.__binds[shard_id] = bind
class ShardedQuery(Query):
def __init__(self, *args, **kwargs):
super(ShardedQuery, self).__init__(*args, **kwargs)
self.id_chooser = self.session.id_chooser
self.query_chooser = self.session.query_chooser
self._shard_id = None
def set_shard(self, shard_id):
"""return a new query, limited to a single shard ID.
all subsequent operations with the returned query will
all subsequent operations with the returned query will
be against the single shard regardless of other state.
"""
q = self._clone()
q._shard_id = shard_id
return q
def _execute_and_instances(self, context):
if self._shard_id is not None:
result = self.session.connection(
mapper=self._mapper_zero(),
shard_id=self._shard_id).execute(context.statement, self._params)
def iter_for_shard(shard_id):
context.attributes['shard_id'] = shard_id
result = self._connection_from_session(
mapper=self._mapper_zero(),
shard_id=shard_id).execute(
context.statement,
self._params)
return self.instances(result, context)
if self._shard_id is not None:
return iter_for_shard(self._shard_id)
else:
partial = []
for shard_id in self.query_chooser(self):
result = self.session.connection(
mapper=self._mapper_zero(),
shard_id=shard_id).execute(context.statement, self._params)
partial = partial + list(self.instances(result, context))
# if some kind of in memory 'sorting'
partial.extend(iter_for_shard(shard_id))
# if some kind of in memory 'sorting'
# were done, this is where it would happen
return iter(partial)
@ -122,4 +72,60 @@ class ShardedQuery(Query):
return o
else:
return None
class ShardedSession(Session):
def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None,
query_cls=ShardedQuery, **kwargs):
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped
instance, and possibly a SQL clause, returns a shard ID. This id
may be based off of the attributes present within the object, or on
some round-robin scheme. If the scheme is based on a selection, it
should set whatever state on the instance to mark it in the future as
participating in that shard.
:param id_chooser: A callable, passed a query and a tuple of identity
values, which should return a list of shard ids where the ID might
reside. The databases will be queried in the order of this listing.
:param query_chooser: For a given Query, returns the list of shard_ids
where the query should be issued. Results from all shards returned
will be combined together into a single listing.
:param shards: A dictionary of string shard names
to :class:`~sqlalchemy.engine.Engine` objects.
"""
super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
self.shard_chooser = shard_chooser
self.id_chooser = id_chooser
self.query_chooser = query_chooser
self.__binds = {}
self.connection_callable = self.connection
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance)
if self.transaction is not None:
return self.transaction.connection(mapper, shard_id=shard_id)
else:
return self.get_bind(
mapper,
shard_id=shard_id,
instance=instance
).contextual_connect(**kwargs)
def get_bind(self, mapper, shard_id=None,
instance=None, clause=None, **kw):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance, clause=clause)
return self.__binds[shard_id]
def bind_shard(self, shard_id, bind):
self.__binds[shard_id] = bind

View File

@ -1,58 +1,78 @@
"""A custom list that manages index/position information for its children.
# ext/orderinglist.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""A custom list that manages index/position information for contained
elements.
:author: Jason Kirtland
``orderinglist`` is a helper for mutable ordered relationships. It will intercept
list operations performed on a relationship collection and automatically
synchronize changes in list position with an attribute on the related objects.
(See :ref:`advdatamapping_entitycollections` for more information on the general pattern.)
``orderinglist`` is a helper for mutable ordered relationships. It will
intercept list operations performed on a :func:`.relationship`-managed
collection and
automatically synchronize changes in list position onto a target scalar
attribute.
Example: Two tables that store slides in a presentation. Each slide
has a number of bullet points, displayed in order by the 'position'
column on the bullets table. These bullets can be inserted and re-ordered
by your end users, and you need to update the 'position' column of all
affected rows when changes are made.
Example: A ``slide`` table, where each row refers to zero or more entries
in a related ``bullet`` table. The bullets within a slide are
displayed in order based on the value of the ``position`` column in the
``bullet`` table. As entries are reordered in memory, the value of the
``position`` attribute should be updated to reflect the new sort order::
.. sourcecode:: python+sql
slides_table = Table('Slides', metadata,
Column('id', Integer, primary_key=True),
Column('name', String))
Base = declarative_base()
bullets_table = Table('Bullets', metadata,
Column('id', Integer, primary_key=True),
Column('slide_id', Integer, ForeignKey('Slides.id')),
Column('position', Integer),
Column('text', String))
class Slide(Base):
__tablename__ = 'slide'
class Slide(object):
pass
class Bullet(object):
pass
id = Column(Integer, primary_key=True)
name = Column(String)
mapper(Slide, slides_table, properties={
'bullets': relationship(Bullet, order_by=[bullets_table.c.position])
})
mapper(Bullet, bullets_table)
bullets = relationship("Bullet", order_by="Bullet.position")
The standard relationship mapping will produce a list-like attribute on each Slide
containing all related Bullets, but coping with changes in ordering is totally
your responsibility. If you insert a Bullet into that list, there is no
magic- it won't have a position attribute unless you assign it it one, and
you'll need to manually renumber all the subsequent Bullets in the list to
accommodate the insert.
class Bullet(Base):
__tablename__ = 'bullet'
id = Column(Integer, primary_key=True)
slide_id = Column(Integer, ForeignKey('slide.id'))
position = Column(Integer)
text = Column(String)
An ``orderinglist`` can automate this and manage the 'position' attribute on all
related bullets for you.
The standard relationship mapping will produce a list-like attribute on each
``Slide`` containing all related ``Bullet`` objects,
but coping with changes in ordering is not handled automatically.
When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position``
attribute will remain unset until manually assigned. When the ``Bullet``
is inserted into the middle of the list, the following ``Bullet`` objects
will also need to be renumbered.
.. sourcecode:: python+sql
mapper(Slide, slides_table, properties={
'bullets': relationship(Bullet,
collection_class=ordering_list('position'),
order_by=[bullets_table.c.position])
})
mapper(Bullet, bullets_table)
The :class:`.OrderingList` object automates this task, managing the
``position`` attribute on all ``Bullet`` objects in the collection. It is
constructed using the :func:`.ordering_list` factory::
from sqlalchemy.ext.orderinglist import ordering_list
Base = declarative_base()
class Slide(Base):
__tablename__ = 'slide'
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position",
collection_class=ordering_list('position'))
class Bullet(Base):
__tablename__ = 'bullet'
id = Column(Integer, primary_key=True)
slide_id = Column(Integer, ForeignKey('slide.id'))
position = Column(Integer)
text = Column(String)
With the above mapping the ``Bullet.position`` attribute is managed::
s = Slide()
s.bullets.append(Bullet())
@ -63,71 +83,98 @@ related bullets for you.
s.bullets[2].position
>>> 2
Use the ``ordering_list`` function to set up the ``collection_class`` on relationships
(as in the mapper example above). This implementation depends on the list
starting in the proper order, so be SURE to put an order_by on your relationship.
The :class:`.OrderingList` construct only works with **changes** to a
collection, and not the initial load from the database, and requires that the
list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the
:func:`.relationship` against the target ordering attribute, so that the
ordering is correct when first loaded.
.. warning:: ``ordering_list`` only provides limited functionality when a primary
key column or unique column is the target of the sort. Since changing the order of
entries often means that two rows must trade values, this is not possible when
the value is constrained by a primary key or unique constraint, since one of the rows
would temporarily have to point to a third available value so that the other row
could take its old value. ``ordering_list`` doesn't do any of this for you,
nor does SQLAlchemy itself.
.. warning::
``ordering_list`` takes the name of the related object's ordering attribute as
an argument. By default, the zero-based integer index of the object's
position in the ``ordering_list`` is synchronized with the ordering attribute:
index 0 will get position 0, index 1 position 1, etc. To start numbering at 1
or some other integer, provide ``count_from=1``.
:class:`.OrderingList` only provides limited functionality when a primary
key column or unique column is the target of the sort. Operations
that are unsupported or are problematic include:
Ordering values are not limited to incrementing integers. Almost any scheme
can implemented by supplying a custom ``ordering_func`` that maps a Python list
index to any value you require.
* two entries must trade values. This is not supported directly in the
case of a primary key or unique constraint because it means at least
one row would need to be temporarily removed first, or changed to
a third, neutral value while the switch occurs.
* an entry must be deleted in order to make room for a new entry.
SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
single flush. In the case of a primary key, it will trade
an INSERT/DELETE of the same primary key for an UPDATE statement in order
to lessen the impact of this limitation, however this does not take place
for a UNIQUE column.
A future feature will allow the "DELETE before INSERT" behavior to be
possible, allevating this limitation, though this feature will require
explicit configuration at the mapper level for sets of columns that
are to be handled in this way.
:func:`.ordering_list` takes the name of the related object's ordering
attribute as an argument. By default, the zero-based integer index of the
object's position in the :func:`.ordering_list` is synchronized with the
ordering attribute: index 0 will get position 0, index 1 position 1, etc. To
start numbering at 1 or some other integer, provide ``count_from=1``.
"""
from sqlalchemy.orm.collections import collection
from sqlalchemy import util
from ..orm.collections import collection, collection_adapter
from .. import util
__all__ = [ 'ordering_list' ]
__all__ = ['ordering_list']
def ordering_list(attr, count_from=None, **kw):
"""Prepares an OrderingList factory for use in mapper definitions.
"""Prepares an :class:`OrderingList` factory for use in mapper definitions.
Returns an object suitable for use as an argument to a Mapper relationship's
``collection_class`` option. Arguments are:
Returns an object suitable for use as an argument to a Mapper
relationship's ``collection_class`` option. e.g.::
attr
from sqlalchemy.ext.orderinglist import ordering_list
class Slide(Base):
__tablename__ = 'slide'
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position",
collection_class=ordering_list('position'))
:param attr:
Name of the mapped attribute to use for storage and retrieval of
ordering information
count_from (optional)
:param count_from:
Set up an integer-based ordering, starting at ``count_from``. For
example, ``ordering_list('pos', count_from=1)`` would create a 1-based
list in SQL, storing the value in the 'pos' column. Ignored if
``ordering_func`` is supplied.
Passes along any keyword arguments to ``OrderingList`` constructor.
Additional arguments are passed to the :class:`.OrderingList` constructor.
"""
kw = _unsugar_count_from(count_from=count_from, **kw)
return lambda: OrderingList(attr, **kw)
# Ordering utility functions
def count_from_0(index, collection):
"""Numbering function: consecutive integers starting at 0."""
return index
def count_from_1(index, collection):
"""Numbering function: consecutive integers starting at 1."""
return index + 1
def count_from_n_factory(start):
"""Numbering function: consecutive integers starting at arbitrary start."""
@ -139,8 +186,9 @@ def count_from_n_factory(start):
pass
return f
def _unsugar_count_from(**kw):
"""Builds counting functions from keywrod arguments.
"""Builds counting functions from keyword arguments.
Keyword argument filter, prepares a simple ``ordering_func`` from a
``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
@ -156,12 +204,13 @@ def _unsugar_count_from(**kw):
kw['ordering_func'] = count_from_n_factory(count_from)
return kw
class OrderingList(list):
"""A custom list that manages position information for its children.
See the module and __init__ documentation for more details. The
``ordering_list`` factory function is used to configure ``OrderingList``
collections in ``mapper`` relationship definitions.
The :class:`.OrderingList` object is normally set up using the
:func:`.ordering_list` factory function, used in conjunction with
the :func:`.relationship` function.
"""
@ -176,14 +225,14 @@ class OrderingList(list):
This implementation relies on the list starting in the proper order,
so be **sure** to put an ``order_by`` on your relationship.
ordering_attr
:param ordering_attr:
Name of the attribute that stores the object's order in the
relationship.
ordering_func
Optional. A function that maps the position in the Python list to a
value to store in the ``ordering_attr``. Values returned are
usually (but need not be!) integers.
:param ordering_func: Optional. A function that maps the position in
the Python list to a value to store in the
``ordering_attr``. Values returned are usually (but need not be!)
integers.
An ``ordering_func`` is called with two positional parameters: the
index of the element in the list, and the list itself.
@ -194,7 +243,7 @@ class OrderingList(list):
like stepped numbering, alphabetical and Fibonacci numbering, see
the unit tests.
reorder_on_append
:param reorder_on_append:
Default False. When appending an object with an existing (non-None)
ordering value, that value will be left untouched unless
``reorder_on_append`` is true. This is an optimization to avoid a
@ -208,7 +257,7 @@ class OrderingList(list):
making changes, any of whom happen to load this collection even in
passing, all of the sessions would try to "clean up" the numbering
in their commits, possibly causing all but one to fail with a
concurrent modification error. Spooky action at a distance.
concurrent modification error.
Recommend leaving this with the default of False, and just call
``reorder()`` if you're doing ``append()`` operations with
@ -270,7 +319,10 @@ class OrderingList(list):
def remove(self, entity):
super(OrderingList, self).remove(entity)
self._reorder()
adapter = collection_adapter(self)
if adapter and adapter._referenced_by_owner:
self._reorder()
def pop(self, index=-1):
entity = super(OrderingList, self).pop(index)
@ -286,8 +338,8 @@ class OrderingList(list):
stop = index.stop or len(self)
if stop < 0:
stop += len(self)
for i in xrange(start, stop, step):
for i in range(start, stop, step):
self.__setitem__(i, entity[i])
else:
self._order_entity(index, entity, True)
@ -297,7 +349,6 @@ class OrderingList(list):
super(OrderingList, self).__delitem__(index)
self._reorder()
# Py2K
def __setslice__(self, start, end, values):
super(OrderingList, self).__setslice__(start, end, values)
self._reorder()
@ -305,11 +356,25 @@ class OrderingList(list):
def __delslice__(self, start, end):
super(OrderingList, self).__delslice__(start, end)
self._reorder()
# end Py2K
for func_name, func in locals().items():
if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(list, func_name)):
def __reduce__(self):
return _reconstitute, (self.__class__, self.__dict__, list(self))
for func_name, func in list(locals().items()):
if (util.callable(func) and func.__name__ == func_name and
not func.__doc__ and hasattr(list, func_name)):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
def _reconstitute(cls, dict_, items):
""" Reconstitute an :class:`.OrderingList`.
This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for
unpickling :class:`.OrderingList` objects.
"""
obj = cls.__new__(cls)
obj.__dict__.update(dict_)
list.extend(obj, items)
return obj

View File

@ -1,4 +1,11 @@
"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
# ext/serializer.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
allowing "contextual" deserialization.
Any SQLAlchemy query structure, either based on sqlalchemy.sql.*
@ -12,82 +19,73 @@ Usage is nearly the same as that of the standard Python pickle module::
from sqlalchemy.ext.serializer import loads, dumps
metadata = MetaData(bind=some_engine)
Session = scoped_session(sessionmaker())
# ... define mappers
query = Session.query(MyClass).filter(MyClass.somedata=='foo').order_by(MyClass.sortkey)
query = Session.query(MyClass).
filter(MyClass.somedata=='foo').order_by(MyClass.sortkey)
# pickle the query
serialized = dumps(query)
# unpickle. Pass in metadata + scoped_session
query2 = loads(serialized, metadata, Session)
print query2.all()
Similar restrictions as when using raw pickle apply; mapped classes must be
Similar restrictions as when using raw pickle apply; mapped classes must be
themselves be pickleable, meaning they are importable from a module-level
namespace.
The serializer module is only appropriate for query structures. It is not
needed for:
* instances of user-defined classes. These contain no references to engines,
sessions or expression constructs in the typical case and can be serialized directly.
* instances of user-defined classes. These contain no references to engines,
sessions or expression constructs in the typical case and can be serialized
directly.
* Table metadata that is to be loaded entirely from the serialized structure (i.e. is
not already declared in the application). Regular pickle.loads()/dumps() can
be used to fully dump any ``MetaData`` object, typically one which was reflected
from an existing database at some previous point in time. The serializer module
is specifically for the opposite case, where the Table metadata is already present
in memory.
* Table metadata that is to be loaded entirely from the serialized structure
(i.e. is not already declared in the application). Regular
pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
typically one which was reflected from an existing database at some previous
point in time. The serializer module is specifically for the opposite case,
where the Table metadata is already present in memory.
"""
from sqlalchemy.orm import class_mapper, Query
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.attributes import QueryableAttribute
from sqlalchemy import Table, Column
from sqlalchemy.engine import Engine
from sqlalchemy.util import pickle
from ..orm import class_mapper
from ..orm.session import Session
from ..orm.mapper import Mapper
from ..orm.interfaces import MapperProperty
from ..orm.attributes import QueryableAttribute
from .. import Table, Column
from ..engine import Engine
from ..util import pickle, byte_buffer, b64encode, b64decode, text_type
import re
import base64
# Py3K
#from io import BytesIO as byte_buffer
# Py2K
from cStringIO import StringIO as byte_buffer
# end Py2K
# Py3K
#def b64encode(x):
# return base64.b64encode(x).decode('ascii')
#def b64decode(x):
# return base64.b64decode(x.encode('ascii'))
# Py2K
b64encode = base64.b64encode
b64decode = base64.b64decode
# end Py2K
__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads']
def Serializer(*args, **kw):
pickler = pickle.Pickler(*args, **kw)
def persistent_id(obj):
#print "serializing:", repr(obj)
# print "serializing:", repr(obj)
if isinstance(obj, QueryableAttribute):
cls = obj.impl.class_
key = obj.impl.key
id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls))
elif isinstance(obj, Mapper) and not obj.non_primary:
id = "mapper:" + b64encode(pickle.dumps(obj.class_))
elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
id = "mapperprop:" + b64encode(pickle.dumps(obj.parent.class_)) + \
":" + obj.key
elif isinstance(obj, Table):
id = "table:" + str(obj)
id = "table:" + text_type(obj.key)
elif isinstance(obj, Column) and isinstance(obj.table, Table):
id = "column:" + str(obj.table) + ":" + obj.key
id = "column:" + \
text_type(obj.table.key) + ":" + text_type(obj.key)
elif isinstance(obj, Session):
id = "session:"
elif isinstance(obj, Engine):
@ -95,15 +93,17 @@ def Serializer(*args, **kw):
else:
return None
return id
pickler.persistent_id = persistent_id
return pickler
our_ids = re.compile(r'(mapper|table|column|session|attribute|engine):(.*)')
our_ids = re.compile(
r'(mapperprop|mapper|table|column|session|attribute|engine):(.*)')
def Deserializer(file, metadata=None, scoped_session=None, engine=None):
unpickler = pickle.Unpickler(file)
def get_engine():
if engine:
return engine
@ -113,9 +113,9 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
return metadata.bind
else:
return None
def persistent_load(id):
m = our_ids.match(id)
m = our_ids.match(text_type(id))
if not m:
return None
else:
@ -127,6 +127,10 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
elif type_ == "mapper":
cls = pickle.loads(b64decode(args))
return class_mapper(cls)
elif type_ == "mapperprop":
mapper, keyname = args.split(':')
cls = pickle.loads(b64decode(mapper))
return class_mapper(cls).attrs[keyname]
elif type_ == "table":
return metadata.tables[args]
elif type_ == "column":
@ -141,15 +145,15 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
unpickler.persistent_load = persistent_load
return unpickler
def dumps(obj, protocol=0):
buf = byte_buffer()
pickler = Serializer(buf, protocol)
pickler.dump(obj)
return buf.getvalue()
def loads(data, metadata=None, scoped_session=None, engine=None):
buf = byte_buffer(data)
unpickler = Deserializer(buf, metadata, scoped_session, engine)
return unpickler.load()

View File

@ -1,31 +1,45 @@
# interfaces.py
# sqlalchemy/interfaces.py
# Copyright (C) 2007-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# Copyright (C) 2007 Jason Kirtland jek@discorporate.us
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Interfaces and abstract types."""
"""Deprecated core event interfaces.
This module is **deprecated** and is superseded by the
event system.
"""
from . import event, util
class PoolListener(object):
"""Hooks into the lifecycle of connections in a ``Pool``.
"""Hooks into the lifecycle of connections in a :class:`.Pool`.
.. note::
:class:`.PoolListener` is deprecated. Please
refer to :class:`.PoolEvents`.
Usage::
class MyListener(PoolListener):
def connect(self, dbapi_con, con_record):
'''perform connect operations'''
# etc.
# etc.
# create a new pool with a listener
p = QueuePool(..., listeners=[MyListener()])
# add a listener after the fact
p.add_listener(MyListener())
# usage with create_engine()
e = create_engine("url://", listeners=[MyListener()])
All of the standard connection :class:`~sqlalchemy.pool.Pool` types can
accept event listeners for key connection lifecycle events:
creation, pool check-out and check-in. There are no events fired
@ -56,9 +70,28 @@ class PoolListener(object):
internal event queues based on its capabilities. In terms of
efficiency and function call overhead, you're much better off only
providing implementations for the hooks you'll be using.
"""
@classmethod
def _adapt_listener(cls, self, listener):
"""Adapt a :class:`.PoolListener` to individual
:class:`event.Dispatch` events.
"""
listener = util.as_interface(listener,
methods=('connect', 'first_connect',
'checkout', 'checkin'))
if hasattr(listener, 'connect'):
event.listen(self, 'connect', listener.connect)
if hasattr(listener, 'first_connect'):
event.listen(self, 'first_connect', listener.first_connect)
if hasattr(listener, 'checkout'):
event.listen(self, 'checkout', listener.checkout)
if hasattr(listener, 'checkin'):
event.listen(self, 'checkin', listener.checkin)
def connect(self, dbapi_con, con_record):
"""Called once for each new DB-API connection or Pool's ``creator()``.
@ -117,89 +150,163 @@ class PoolListener(object):
"""
class ConnectionProxy(object):
"""Allows interception of statement execution by Connections.
.. note::
:class:`.ConnectionProxy` is deprecated. Please
refer to :class:`.ConnectionEvents`.
Either or both of the ``execute()`` and ``cursor_execute()``
may be implemented to intercept compiled statement and
cursor level executions, e.g.::
class MyProxy(ConnectionProxy):
def execute(self, conn, execute, clauseelement, *multiparams, **params):
def execute(self, conn, execute, clauseelement,
*multiparams, **params):
print "compiled statement:", clauseelement
return execute(clauseelement, *multiparams, **params)
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
def cursor_execute(self, execute, cursor, statement,
parameters, context, executemany):
print "raw statement:", statement
return execute(cursor, statement, parameters, context)
The ``execute`` argument is a function that will fulfill the default
execution behavior for the operation. The signature illustrated
in the example should be used.
The proxy is installed into an :class:`~sqlalchemy.engine.Engine` via
the ``proxy`` argument::
e = create_engine('someurl://', proxy=MyProxy())
"""
@classmethod
def _adapt_listener(cls, self, listener):
def adapt_execute(conn, clauseelement, multiparams, params):
def execute_wrapper(clauseelement, *multiparams, **params):
return clauseelement, multiparams, params
return listener.execute(conn, execute_wrapper,
clauseelement, *multiparams,
**params)
event.listen(self, 'before_execute', adapt_execute)
def adapt_cursor_execute(conn, cursor, statement,
parameters, context, executemany):
def execute_wrapper(
cursor,
statement,
parameters,
context,
):
return statement, parameters
return listener.cursor_execute(
execute_wrapper,
cursor,
statement,
parameters,
context,
executemany,
)
event.listen(self, 'before_cursor_execute', adapt_cursor_execute)
def do_nothing_callback(*arg, **kw):
pass
def adapt_listener(fn):
def go(conn, *arg, **kw):
fn(conn, do_nothing_callback, *arg, **kw)
return util.update_wrapper(go, fn)
event.listen(self, 'begin', adapt_listener(listener.begin))
event.listen(self, 'rollback',
adapt_listener(listener.rollback))
event.listen(self, 'commit', adapt_listener(listener.commit))
event.listen(self, 'savepoint',
adapt_listener(listener.savepoint))
event.listen(self, 'rollback_savepoint',
adapt_listener(listener.rollback_savepoint))
event.listen(self, 'release_savepoint',
adapt_listener(listener.release_savepoint))
event.listen(self, 'begin_twophase',
adapt_listener(listener.begin_twophase))
event.listen(self, 'prepare_twophase',
adapt_listener(listener.prepare_twophase))
event.listen(self, 'rollback_twophase',
adapt_listener(listener.rollback_twophase))
event.listen(self, 'commit_twophase',
adapt_listener(listener.commit_twophase))
def execute(self, conn, execute, clauseelement, *multiparams, **params):
"""Intercept high level execute() events."""
return execute(clauseelement, *multiparams, **params)
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
def cursor_execute(self, execute, cursor, statement, parameters,
context, executemany):
"""Intercept low-level cursor execute() events."""
return execute(cursor, statement, parameters, context)
def begin(self, conn, begin):
"""Intercept begin() events."""
return begin()
def rollback(self, conn, rollback):
"""Intercept rollback() events."""
return rollback()
def commit(self, conn, commit):
"""Intercept commit() events."""
return commit()
def savepoint(self, conn, savepoint, name=None):
"""Intercept savepoint() events."""
return savepoint(name=name)
def rollback_savepoint(self, conn, rollback_savepoint, name, context):
"""Intercept rollback_savepoint() events."""
return rollback_savepoint(name, context)
def release_savepoint(self, conn, release_savepoint, name, context):
"""Intercept release_savepoint() events."""
return release_savepoint(name, context)
def begin_twophase(self, conn, begin_twophase, xid):
"""Intercept begin_twophase() events."""
return begin_twophase(xid)
def prepare_twophase(self, conn, prepare_twophase, xid):
"""Intercept prepare_twophase() events."""
return prepare_twophase(xid)
def rollback_twophase(self, conn, rollback_twophase, xid, is_prepared):
"""Intercept rollback_twophase() events."""
return rollback_twophase(xid, is_prepared)
def commit_twophase(self, conn, commit_twophase, xid, is_prepared):
"""Intercept commit_twophase() events."""
return commit_twophase(xid, is_prepared)

View File

@ -1,5 +1,7 @@
# log.py - adapt python logging module to SQLAlchemy
# Copyright (C) 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
# sqlalchemy/log.py
# Copyright (C) 2006-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -10,92 +12,189 @@ Control of logging for SA can be performed from the regular python logging
module. The regular dotted module namespace is used, starting at
'sqlalchemy'. For class-level logging, the class name is appended.
The "echo" keyword parameter which is available on SQLA ``Engine``
and ``Pool`` objects corresponds to a logger specific to that
The "echo" keyword parameter, available on SQLA :class:`.Engine`
and :class:`.Pool` objects, corresponds to a logger specific to that
instance only.
E.g.::
engine.echo = True
is equivalent to::
import logging
logger = logging.getLogger('sqlalchemy.engine.Engine.%s' % hex(id(engine)))
logger.setLevel(logging.DEBUG)
"""
import logging
import sys
from sqlalchemy import util
# set initial level to WARN. This so that
# log statements don't occur in the absence of explicit
# logging being enabled for 'sqlalchemy'.
rootlogger = logging.getLogger('sqlalchemy')
if rootlogger.level == logging.NOTSET:
rootlogger.setLevel(logging.WARN)
default_enabled = False
def default_logging(name):
global default_enabled
if logging.getLogger(name).getEffectiveLevel() < logging.WARN:
default_enabled = True
if not default_enabled:
default_enabled = True
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s %(name)s %(message)s'))
rootlogger.addHandler(handler)
def _add_default_handler(logger):
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s %(name)s %(message)s'))
logger.addHandler(handler)
_logged_classes = set()
def class_logger(cls, enable=False):
def class_logger(cls):
logger = logging.getLogger(cls.__module__ + "." + cls.__name__)
if enable == 'debug':
logger.setLevel(logging.DEBUG)
elif enable == 'info':
logger.setLevel(logging.INFO)
cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG)
cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO)
cls.logger = logger
_logged_classes.add(cls)
return cls
class Identified(object):
@util.memoized_property
def logging_name(self):
# limit the number of loggers by chopping off the hex(id).
# some novice users unfortunately create an unlimited number
# of Engines in their applications which would otherwise
# cause the app to run out of memory.
return "0x...%s" % hex(id(self))[-4:]
logging_name = None
def instance_logger(instance, echoflag=None):
"""create a logger for an instance that implements :class:`Identified`.
Warning: this is an expensive call which also results in a permanent
increase in memory overhead for each call. Use only for
low-volume, long-time-spanning objects.
def _should_log_debug(self):
return self.logger.isEnabledFor(logging.DEBUG)
def _should_log_info(self):
return self.logger.isEnabledFor(logging.INFO)
class InstanceLogger(object):
"""A logger adapter (wrapper) for :class:`.Identified` subclasses.
This allows multiple instances (e.g. Engine or Pool instances)
to share a logger, but have its verbosity controlled on a
per-instance basis.
The basic functionality is to return a logging level
which is based on an instance's echo setting.
Default implementation is:
'debug' -> logging.DEBUG
True -> logging.INFO
False -> Effective level of underlying logger
(logging.WARNING by default)
None -> same as False
"""
name = "%s.%s.%s" % (instance.__class__.__module__,
instance.__class__.__name__, instance.logging_name)
if echoflag is not None:
l = logging.getLogger(name)
if echoflag == 'debug':
default_logging(name)
l.setLevel(logging.DEBUG)
elif echoflag is True:
default_logging(name)
l.setLevel(logging.INFO)
elif echoflag is False:
l.setLevel(logging.WARN)
# Map echo settings to logger levels
_echo_map = {
None: logging.NOTSET,
False: logging.NOTSET,
True: logging.INFO,
'debug': logging.DEBUG,
}
def __init__(self, echo, name):
self.echo = echo
self.logger = logging.getLogger(name)
# if echo flag is enabled and no handlers,
# add a handler to the list
if self._echo_map[echo] <= logging.INFO \
and not self.logger.handlers:
_add_default_handler(self.logger)
#
# Boilerplate convenience methods
#
def debug(self, msg, *args, **kwargs):
"""Delegate a debug call to the underlying logger."""
self.log(logging.DEBUG, msg, *args, **kwargs)
def info(self, msg, *args, **kwargs):
"""Delegate an info call to the underlying logger."""
self.log(logging.INFO, msg, *args, **kwargs)
def warning(self, msg, *args, **kwargs):
"""Delegate a warning call to the underlying logger."""
self.log(logging.WARNING, msg, *args, **kwargs)
warn = warning
def error(self, msg, *args, **kwargs):
"""
Delegate an error call to the underlying logger.
"""
self.log(logging.ERROR, msg, *args, **kwargs)
def exception(self, msg, *args, **kwargs):
"""Delegate an exception call to the underlying logger."""
kwargs["exc_info"] = 1
self.log(logging.ERROR, msg, *args, **kwargs)
def critical(self, msg, *args, **kwargs):
"""Delegate a critical call to the underlying logger."""
self.log(logging.CRITICAL, msg, *args, **kwargs)
def log(self, level, msg, *args, **kwargs):
"""Delegate a log call to the underlying logger.
The level here is determined by the echo
flag as well as that of the underlying logger, and
logger._log() is called directly.
"""
# inline the logic from isEnabledFor(),
# getEffectiveLevel(), to avoid overhead.
if self.logger.manager.disable >= level:
return
selected_level = self._echo_map[self.echo]
if selected_level == logging.NOTSET:
selected_level = self.logger.getEffectiveLevel()
if level >= selected_level:
self.logger._log(level, msg, args, **kwargs)
def isEnabledFor(self, level):
"""Is this logger enabled for level 'level'?"""
if self.logger.manager.disable >= level:
return False
return level >= self.getEffectiveLevel()
def getEffectiveLevel(self):
"""What's the effective level for this logger?"""
level = self._echo_map[self.echo]
if level == logging.NOTSET:
level = self.logger.getEffectiveLevel()
return level
def instance_logger(instance, echoflag=None):
"""create a logger for an instance that implements :class:`.Identified`."""
if instance.logging_name:
name = "%s.%s.%s" % (instance.__class__.__module__,
instance.__class__.__name__,
instance.logging_name)
else:
l = logging.getLogger(name)
instance._should_log_debug = lambda: l.isEnabledFor(logging.DEBUG)
instance._should_log_info = lambda: l.isEnabledFor(logging.INFO)
return l
name = "%s.%s" % (instance.__class__.__module__,
instance.__class__.__name__)
instance._echo = echoflag
if echoflag in (False, None):
# if no echo setting or False, return a Logger directly,
# avoiding overhead of filtering
logger = logging.getLogger(name)
else:
# if a specified echo flag, return an EchoLogger,
# which checks the flag, overrides normal log
# levels by calling logger._log()
logger = InstanceLogger(echoflag, name)
instance.logger = logger
class echo_property(object):
__doc__ = """\
@ -112,8 +211,7 @@ class echo_property(object):
if instance is None:
return self
else:
return instance._should_log_debug() and 'debug' or \
(instance._should_log_info() and True or False)
return instance._echo
def __set__(self, instance, value):
instance_logger(instance, echoflag=value)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,6 @@
# dynamic.py
# Copyright (C) the SQLAlchemy authors and contributors
# orm/dynamic.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -11,42 +12,47 @@ basic add/delete mutation.
"""
from sqlalchemy import log, util
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import exc as sa_exc
from sqlalchemy.sql import operators
from sqlalchemy.orm import (
attributes, object_session, util as mapperutil, strategies, object_mapper
)
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.util import _state_has_identity, has_identity
from sqlalchemy.orm import attributes, collections
from .. import log, util, exc
from ..sql import operators
from . import (
attributes, object_session, util as orm_util, strategies,
object_mapper, exc as orm_exc, properties
)
from .query import Query
@log.class_logger
@properties.RelationshipProperty.strategy_for(lazy="dynamic")
class DynaLoader(strategies.AbstractRelationshipLoader):
def init_class_attribute(self, mapper):
self.is_class_level = True
strategies._register_attribute(self,
if not self.uselist:
raise exc.InvalidRequestError(
"On relationship %s, 'dynamic' loaders cannot be used with "
"many-to-one/one-to-one relationships and/or "
"uselist=False." % self.parent_property)
strategies._register_attribute(
self.parent_property,
mapper,
useobject=True,
impl_class=DynamicAttributeImpl,
target_mapper=self.parent_property.mapper,
order_by=self.parent_property.order_by,
query_class=self.parent_property.query_class
query_class=self.parent_property.query_class,
)
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
return (None, None)
log.class_logger(DynaLoader)
class DynamicAttributeImpl(attributes.AttributeImpl):
uses_objects = True
accepts_scalar_loader = False
supports_population = False
collection = False
def __init__(self, class_, key, typecallable,
target_mapper, order_by, query_class=None, **kwargs):
super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs)
dispatch,
target_mapper, order_by, query_class=None, **kw):
super(DynamicAttributeImpl, self).\
__init__(class_, key, typecallable, dispatch, **kw)
self.target_mapper = target_mapper
self.order_by = order_by
if not query_class:
@ -56,178 +62,204 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
else:
self.query_class = mixin_user_query(query_class)
def get(self, state, dict_, passive=False):
if passive:
return self._get_collection_history(state, passive=True).added_items
def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
if not passive & attributes.SQL_OK:
return self._get_collection_history(
state, attributes.PASSIVE_NO_INITIALIZE).added_items
else:
return self.query_class(self, state)
def get_collection(self, state, dict_, user_data=None, passive=True):
if passive:
return self._get_collection_history(state, passive=passive).added_items
def get_collection(self, state, dict_, user_data=None,
passive=attributes.PASSIVE_NO_INITIALIZE):
if not passive & attributes.SQL_OK:
return self._get_collection_history(state,
passive).added_items
else:
history = self._get_collection_history(state, passive=passive)
return history.added_items + history.unchanged_items
history = self._get_collection_history(state, passive)
return history.added_plus_unchanged
def fire_append_event(self, state, dict_, value, initiator):
collection_history = self._modified_event(state, dict_)
collection_history.added_items.append(value)
@util.memoized_property
def _append_token(self):
return attributes.Event(self, attributes.OP_APPEND)
for ext in self.extensions:
ext.append(state, value, initiator or self)
@util.memoized_property
def _remove_token(self):
return attributes.Event(self, attributes.OP_REMOVE)
def fire_append_event(self, state, dict_, value, initiator,
collection_history=None):
if collection_history is None:
collection_history = self._modified_event(state, dict_)
collection_history.add_added(value)
for fn in self.dispatch.append:
value = fn(state, value, initiator or self._append_token)
if self.trackparent and value is not None:
self.sethasparent(attributes.instance_state(value), True)
self.sethasparent(attributes.instance_state(value), state, True)
def fire_remove_event(self, state, dict_, value, initiator):
collection_history = self._modified_event(state, dict_)
collection_history.deleted_items.append(value)
def fire_remove_event(self, state, dict_, value, initiator,
collection_history=None):
if collection_history is None:
collection_history = self._modified_event(state, dict_)
collection_history.add_removed(value)
if self.trackparent and value is not None:
self.sethasparent(attributes.instance_state(value), False)
self.sethasparent(attributes.instance_state(value), state, False)
for ext in self.extensions:
ext.remove(state, value, initiator or self)
for fn in self.dispatch.remove:
fn(state, value, initiator or self._remove_token)
def _modified_event(self, state, dict_):
if self.key not in state.committed_state:
state.committed_state[self.key] = CollectionHistory(self, state)
state.modified_event(dict_,
self,
False,
attributes.NEVER_SET,
passive=attributes.PASSIVE_NO_INITIALIZE)
state._modified_event(dict_,
self,
attributes.NEVER_SET)
# this is a hack to allow the _base.ComparableEntity fixture
# this is a hack to allow the fixtures.ComparableEntity fixture
# to work
dict_[self.key] = True
return state.committed_state[self.key]
def set(self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF):
if initiator is self:
def set(self, state, dict_, value, initiator=None,
passive=attributes.PASSIVE_OFF,
check_old=None, pop=False, _adapt=True):
if initiator and initiator.parent_token is self.parent_token:
return
self._set_iterable(state, dict_, value)
if pop and value is None:
return
def _set_iterable(self, state, dict_, iterable, adapter=None):
iterable = value
new_values = list(iterable)
if state.has_identity:
old_collection = util.IdentitySet(self.get(state, dict_))
collection_history = self._modified_event(state, dict_)
new_values = list(iterable)
if _state_has_identity(state):
old_collection = list(self.get(state, dict_))
if not state.has_identity:
old_collection = collection_history.added_items
else:
old_collection = []
old_collection = old_collection.union(
collection_history.added_items)
collections.bulk_replace(new_values, DynCollectionAdapter(self, state, old_collection), DynCollectionAdapter(self, state, new_values))
idset = util.IdentitySet
constants = old_collection.intersection(new_values)
additions = idset(new_values).difference(constants)
removals = old_collection.difference(constants)
for member in new_values:
if member in additions:
self.fire_append_event(state, dict_, member, None,
collection_history=collection_history)
for member in removals:
self.fire_remove_event(state, dict_, member, None,
collection_history=collection_history)
def delete(self, *args, **kwargs):
raise NotImplementedError()
def get_history(self, state, dict_, passive=False):
c = self._get_collection_history(state, passive)
return attributes.History(c.added_items, c.unchanged_items, c.deleted_items)
def set_committed_value(self, state, dict_, value):
raise NotImplementedError("Dynamic attributes don't support "
"collection population.")
def _get_collection_history(self, state, passive=False):
def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
c = self._get_collection_history(state, passive)
return c.as_history()
def get_all_pending(self, state, dict_,
passive=attributes.PASSIVE_NO_INITIALIZE):
c = self._get_collection_history(
state, passive)
return [
(attributes.instance_state(x), x)
for x in
c.all_items
]
def _get_collection_history(self, state, passive=attributes.PASSIVE_OFF):
if self.key in state.committed_state:
c = state.committed_state[self.key]
else:
c = CollectionHistory(self, state)
if not passive:
if state.has_identity and (passive & attributes.INIT_OK):
return CollectionHistory(self, state, apply_to=c)
else:
return c
def append(self, state, dict_, value, initiator, passive=False):
def append(self, state, dict_, value, initiator,
passive=attributes.PASSIVE_OFF):
if initiator is not self:
self.fire_append_event(state, dict_, value, initiator)
def remove(self, state, dict_, value, initiator, passive=False):
def remove(self, state, dict_, value, initiator,
passive=attributes.PASSIVE_OFF):
if initiator is not self:
self.fire_remove_event(state, dict_, value, initiator)
class DynCollectionAdapter(object):
"""the dynamic analogue to orm.collections.CollectionAdapter"""
def pop(self, state, dict_, value, initiator,
passive=attributes.PASSIVE_OFF):
self.remove(state, dict_, value, initiator, passive=passive)
def __init__(self, attr, owner_state, data):
self.attr = attr
self.state = owner_state
self.data = data
def __iter__(self):
return iter(self.data)
def append_with_event(self, item, initiator=None):
self.attr.append(self.state, self.state.dict, item, initiator)
def remove_with_event(self, item, initiator=None):
self.attr.remove(self.state, self.state.dict, item, initiator)
def append_without_event(self, item):
pass
def remove_without_event(self, item):
pass
class AppenderMixin(object):
query_class = None
def __init__(self, attr, state):
Query.__init__(self, attr.target_mapper, None)
super(AppenderMixin, self).__init__(attr.target_mapper, None)
self.instance = instance = state.obj()
self.attr = attr
mapper = object_mapper(instance)
prop = mapper.get_property(self.attr.key, resolve_synonyms=True)
self._criterion = prop.compare(
operators.eq,
instance,
value_is_parent=True,
alias_secondary=False)
prop = mapper._props[self.attr.key]
self._criterion = prop._with_parent(
instance,
alias_secondary=False)
if self.attr.order_by:
self._order_by = self.attr.order_by
def __session(self):
def session(self):
sess = object_session(self.instance)
if sess is not None and self.autoflush and sess.autoflush and self.instance in sess:
if sess is not None and self.autoflush and sess.autoflush \
and self.instance in sess:
sess.flush()
if not has_identity(self.instance):
if not orm_util.has_identity(self.instance):
return None
else:
return sess
def session(self):
return self.__session()
session = property(session, lambda s, x:None)
session = property(session, lambda s, x: None)
def __iter__(self):
sess = self.__session()
sess = self.session
if sess is None:
return iter(self.attr._get_collection_history(
attributes.instance_state(self.instance),
passive=True).added_items)
attributes.PASSIVE_NO_INITIALIZE).added_items)
else:
return iter(self._clone(sess))
def __getitem__(self, index):
sess = self.__session()
sess = self.session
if sess is None:
return self.attr._get_collection_history(
attributes.instance_state(self.instance),
passive=True).added_items.__getitem__(index)
attributes.PASSIVE_NO_INITIALIZE).indexed(index)
else:
return self._clone(sess).__getitem__(index)
def count(self):
sess = self.__session()
sess = self.session
if sess is None:
return len(self.attr._get_collection_history(
attributes.instance_state(self.instance),
passive=True).added_items)
attributes.PASSIVE_NO_INITIALIZE).added_items)
else:
return self._clone(sess).count()
@ -243,26 +275,32 @@ class AppenderMixin(object):
"Parent instance %s is not bound to a Session, and no "
"contextual session is established; lazy load operation "
"of attribute '%s' cannot proceed" % (
mapperutil.instance_str(instance), self.attr.key))
orm_util.instance_str(instance), self.attr.key))
if self.query_class:
query = self.query_class(self.attr.target_mapper, session=sess)
else:
query = sess.query(self.attr.target_mapper)
query._criterion = self._criterion
query._order_by = self._order_by
return query
def extend(self, iterator):
for item in iterator:
self.attr.append(
attributes.instance_state(self.instance),
attributes.instance_dict(self.instance), item, None)
def append(self, item):
self.attr.append(
attributes.instance_state(self.instance),
attributes.instance_state(self.instance),
attributes.instance_dict(self.instance), item, None)
def remove(self, item):
self.attr.remove(
attributes.instance_state(self.instance),
attributes.instance_state(self.instance),
attributes.instance_dict(self.instance), item, None)
@ -275,19 +313,55 @@ def mixin_user_query(cls):
name = 'Appender' + cls.__name__
return type(name, (AppenderMixin, cls), {'query_class': cls})
class CollectionHistory(object):
"""Overrides AttributeHistory to receive append/remove events directly."""
def __init__(self, attr, state, apply_to=None):
if apply_to:
deleted = util.IdentitySet(apply_to.deleted_items)
added = apply_to.added_items
coll = AppenderQuery(attr, state).autoflush(False)
self.unchanged_items = [o for o in util.IdentitySet(coll) if o not in deleted]
self.unchanged_items = util.OrderedIdentitySet(coll)
self.added_items = apply_to.added_items
self.deleted_items = apply_to.deleted_items
self._reconcile_collection = True
else:
self.deleted_items = []
self.added_items = []
self.unchanged_items = []
self.deleted_items = util.OrderedIdentitySet()
self.added_items = util.OrderedIdentitySet()
self.unchanged_items = util.OrderedIdentitySet()
self._reconcile_collection = False
@property
def added_plus_unchanged(self):
return list(self.added_items.union(self.unchanged_items))
@property
def all_items(self):
return list(self.added_items.union(
self.unchanged_items).union(self.deleted_items))
def as_history(self):
if self._reconcile_collection:
added = self.added_items.difference(self.unchanged_items)
deleted = self.deleted_items.intersection(self.unchanged_items)
unchanged = self.unchanged_items.difference(deleted)
else:
added, unchanged, deleted = self.added_items,\
self.unchanged_items,\
self.deleted_items
return attributes.History(
list(added),
list(unchanged),
list(deleted),
)
def indexed(self, index):
return list(self.added_items)[index]
def add_added(self, value):
self.added_items.add(value)
def add_removed(self, value):
if value in self.added_items:
self.added_items.remove(value)
else:
self.deleted_items.add(value)

View File

@ -1,17 +1,21 @@
# orm/evaluator.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import operator
from sqlalchemy.sql import operators, functions
from sqlalchemy.sql import expression as sql
from ..sql import operators
class UnevaluatableError(Exception):
pass
_straight_ops = set(getattr(operators, op)
for op in ('add', 'mul', 'sub',
# Py2K
'div',
# end Py2K
'mod', 'truediv',
for op in ('add', 'mul', 'sub',
'div',
'mod', 'truediv',
'lt', 'le', 'ne', 'gt', 'ge', 'eq'))
@ -20,11 +24,16 @@ _notimplemented_ops = set(getattr(operators, op)
'notilike_op', 'between_op', 'in_op',
'notin_op', 'endswith_op', 'concat_op'))
class EvaluatorCompiler(object):
def __init__(self, target_cls=None):
self.target_cls = target_cls
def process(self, clause):
meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
if not meth:
raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__)
raise UnevaluatableError(
"Cannot evaluate %s" % type(clause).__name__)
return meth(clause)
def visit_grouping(self, clause):
@ -33,16 +42,30 @@ class EvaluatorCompiler(object):
def visit_null(self, clause):
return lambda obj: None
def visit_false(self, clause):
return lambda obj: False
def visit_true(self, clause):
return lambda obj: True
def visit_column(self, clause):
if 'parentmapper' in clause._annotations:
key = clause._annotations['parentmapper']._get_col_to_prop(clause).key
parentmapper = clause._annotations['parentmapper']
if self.target_cls and not issubclass(
self.target_cls, parentmapper.class_):
raise UnevaluatableError(
"Can't evaluate criteria against alternate class %s" %
parentmapper.class_
)
key = parentmapper._columntoproperty[clause].key
else:
key = clause.key
get_corresponding_attr = operator.attrgetter(key)
return lambda obj: get_corresponding_attr(obj)
def visit_clauselist(self, clause):
evaluators = map(self.process, clause.clauses)
evaluators = list(map(self.process, clause.clauses))
if clause.operator is operators.or_:
def evaluate(obj):
has_null = False
@ -64,12 +87,15 @@ class EvaluatorCompiler(object):
return False
return True
else:
raise UnevaluatableError("Cannot evaluate clauselist with operator %s" % clause.operator)
raise UnevaluatableError(
"Cannot evaluate clauselist with operator %s" %
clause.operator)
return evaluate
def visit_binary(self, clause):
eval_left,eval_right = map(self.process, [clause.left, clause.right])
eval_left, eval_right = list(map(self.process,
[clause.left, clause.right]))
operator = clause.operator
if operator is operators.is_:
def evaluate(obj):
@ -85,7 +111,9 @@ class EvaluatorCompiler(object):
return None
return operator(eval_left(obj), eval_right(obj))
else:
raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))
raise UnevaluatableError(
"Cannot evaluate %s with operator %s" %
(type(clause).__name__, clause.operator))
return evaluate
def visit_unary(self, clause):
@ -97,8 +125,13 @@ class EvaluatorCompiler(object):
return None
return not value
return evaluate
raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))
raise UnevaluatableError(
"Cannot evaluate %s with operator %s" %
(type(clause).__name__, clause.operator))
def visit_bindparam(self, clause):
val = clause.value
if clause.callable:
val = clause.callable()
else:
val = clause.value
return lambda obj: val

View File

@ -1,42 +1,79 @@
# exc.py - ORM exceptions
# Copyright (C) the SQLAlchemy authors and contributors
# orm/exc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""SQLAlchemy ORM exceptions."""
import sqlalchemy as sa
from .. import exc as sa_exc, util
NO_STATE = (AttributeError, KeyError)
"""Exception types that may be raised by instrumentation implementations."""
class ConcurrentModificationError(sa.exc.SQLAlchemyError):
"""Rows have been modified outside of the unit of work."""
class StaleDataError(sa_exc.SQLAlchemyError):
"""An operation encountered database state that is unaccounted for.
Conditions which cause this to happen include:
* A flush may have attempted to update or delete rows
and an unexpected number of rows were matched during
the UPDATE or DELETE statement. Note that when
version_id_col is used, rows in UPDATE or DELETE statements
are also matched against the current known version
identifier.
* A mapped object with version_id_col was refreshed,
and the version number coming back from the database does
not match that of the object itself.
* A object is detached from its parent object, however
the object was previously attached to a different parent
identity which was garbage collected, and a decision
cannot be made if the new parent was really the most
recent "parent".
.. versionadded:: 0.7.4
"""
ConcurrentModificationError = StaleDataError
class FlushError(sa.exc.SQLAlchemyError):
class FlushError(sa_exc.SQLAlchemyError):
"""A invalid condition was detected during flush()."""
class UnmappedError(sa.exc.InvalidRequestError):
"""TODO"""
class UnmappedError(sa_exc.InvalidRequestError):
"""Base for exceptions that involve expected mappings not present."""
class ObjectDereferencedError(sa_exc.SQLAlchemyError):
"""An operation cannot complete due to an object being garbage
collected.
"""
class DetachedInstanceError(sa_exc.SQLAlchemyError):
"""An attempt to access unloaded attributes on a
mapped instance that is detached."""
class DetachedInstanceError(sa.exc.SQLAlchemyError):
"""An attempt to access unloaded attributes on a mapped instance that is detached."""
class UnmappedInstanceError(UnmappedError):
"""An mapping operation was requested for an unknown instance."""
def __init__(self, obj, msg=None):
@util.dependencies("sqlalchemy.orm.base")
def __init__(self, base, obj, msg=None):
if not msg:
try:
mapper = sa.orm.class_mapper(type(obj))
base.class_mapper(type(obj))
name = _safe_cls_name(type(obj))
msg = ("Class %r is mapped, but this instance lacks "
"instrumentation. This occurs when the instance is created "
"before sqlalchemy.orm.mapper(%s) was called." % (name, name))
"instrumentation. This occurs when the instance"
"is created before sqlalchemy.orm.mapper(%s) "
"was called." % (name, name))
except UnmappedClassError:
msg = _default_unmapped(type(obj))
if isinstance(obj, type):
@ -45,6 +82,9 @@ class UnmappedInstanceError(UnmappedError):
'required?' % _safe_cls_name(obj))
UnmappedError.__init__(self, msg)
def __reduce__(self):
return self.__class__, (None, self.args[0])
class UnmappedClassError(UnmappedError):
"""An mapping operation was requested for an unknown class."""
@ -54,28 +94,53 @@ class UnmappedClassError(UnmappedError):
msg = _default_unmapped(cls)
UnmappedError.__init__(self, msg)
class ObjectDeletedError(sa.exc.InvalidRequestError):
"""An refresh() operation failed to re-retrieve an object's row."""
def __reduce__(self):
return self.__class__, (None, self.args[0])
class UnmappedColumnError(sa.exc.InvalidRequestError):
class ObjectDeletedError(sa_exc.InvalidRequestError):
"""A refresh operation failed to retrieve the database
row corresponding to an object's known primary key identity.
A refresh operation proceeds when an expired attribute is
accessed on an object, or when :meth:`.Query.get` is
used to retrieve an object which is, upon retrieval, detected
as expired. A SELECT is emitted for the target row
based on primary key; if no row is returned, this
exception is raised.
The true meaning of this exception is simply that
no row exists for the primary key identifier associated
with a persistent object. The row may have been
deleted, or in some cases the primary key updated
to a new value, outside of the ORM's management of the target
object.
"""
@util.dependencies("sqlalchemy.orm.base")
def __init__(self, base, state, msg=None):
if not msg:
msg = "Instance '%s' has been deleted, or its "\
"row is otherwise not present." % base.state_str(state)
sa_exc.InvalidRequestError.__init__(self, msg)
def __reduce__(self):
return self.__class__, (None, self.args[0])
class UnmappedColumnError(sa_exc.InvalidRequestError):
"""Mapping operation was requested on an unknown column."""
class NoResultFound(sa.exc.InvalidRequestError):
class NoResultFound(sa_exc.InvalidRequestError):
"""A database result was required but none was found."""
class MultipleResultsFound(sa.exc.InvalidRequestError):
class MultipleResultsFound(sa_exc.InvalidRequestError):
"""A single database result was required but more than one were found."""
# Legacy compat until 0.6.
sa.exc.ConcurrentModificationError = ConcurrentModificationError
sa.exc.FlushError = FlushError
sa.exc.UnmappedColumnError
def _safe_cls_name(cls):
try:
cls_name = '.'.join((cls.__module__, cls.__name__))
@ -85,9 +150,11 @@ def _safe_cls_name(cls):
cls_name = repr(cls)
return cls_name
def _default_unmapped(cls):
@util.dependencies("sqlalchemy.orm.base")
def _default_unmapped(base, cls):
try:
mappers = sa.orm.attributes.manager_of_class(cls).mappers
mappers = base.manager_of_class(cls).mappers
except NO_STATE:
mappers = {}
except TypeError:

View File

@ -1,67 +1,66 @@
# identity.py
# Copyright (C) the SQLAlchemy authors and contributors
# orm/identity.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import weakref
from . import attributes
from .. import util
from .. import exc as sa_exc
from . import util as orm_util
from sqlalchemy import util as base_util
from sqlalchemy.orm import attributes
class IdentityMap(dict):
class IdentityMap(object):
def __init__(self):
self._mutable_attrs = set()
self._dict = {}
self._modified = set()
self._wr = weakref.ref(self)
def keys(self):
return self._dict.keys()
def replace(self, state):
raise NotImplementedError()
def add(self, state):
raise NotImplementedError()
def remove(self, state):
raise NotImplementedError()
def _add_unpresent(self, state, key):
"""optional inlined form of add() which can assume item isn't present
in the map"""
self.add(state)
def update(self, dict):
raise NotImplementedError("IdentityMap uses add() to insert data")
def clear(self):
raise NotImplementedError("IdentityMap uses remove() to remove data")
def _manage_incoming_state(self, state):
state._instance_dict = self._wr
if state.modified:
self._modified.add(state)
if state.manager.mutable_attributes:
self._mutable_attrs.add(state)
self._modified.add(state)
def _manage_removed_state(self, state):
del state._instance_dict
self._mutable_attrs.discard(state)
self._modified.discard(state)
if state.modified:
self._modified.discard(state)
def _dirty_states(self):
return self._modified.union(s for s in self._mutable_attrs.copy()
if s.modified)
return self._modified
def check_modified(self):
"""return True if any InstanceStates present have been marked as 'modified'."""
if self._modified:
return True
else:
for state in self._mutable_attrs.copy():
if state.modified:
return True
return False
"""return True if any InstanceStates present have been marked
as 'modified'.
"""
return bool(self._modified)
def has_key(self, key):
return key in self
def popitem(self):
raise NotImplementedError("IdentityMap uses remove() to remove data")
@ -71,6 +70,9 @@ class IdentityMap(dict):
def setdefault(self, key, default=None):
raise NotImplementedError("IdentityMap uses add() to insert data")
def __len__(self):
return len(self._dict)
def copy(self):
raise NotImplementedError()
@ -79,164 +81,233 @@ class IdentityMap(dict):
def __delitem__(self, key):
raise NotImplementedError("IdentityMap uses remove() to remove data")
class WeakInstanceDict(IdentityMap):
def __getitem__(self, key):
state = dict.__getitem__(self, key)
state = self._dict[key]
o = state.obj()
if o is None:
o = state._is_really_none()
if o is None:
raise KeyError, key
raise KeyError(key)
return o
def __contains__(self, key):
try:
if dict.__contains__(self, key):
state = dict.__getitem__(self, key)
if key in self._dict:
state = self._dict[key]
o = state.obj()
if o is None:
o = state._is_really_none()
else:
return False
except KeyError:
return False
else:
return o is not None
def contains_state(self, state):
return dict.get(self, state.key) is state
return state.key in self._dict and self._dict[state.key] is state
def replace(self, state):
if dict.__contains__(self, state.key):
existing = dict.__getitem__(self, state.key)
if state.key in self._dict:
existing = self._dict[state.key]
if existing is not state:
self._manage_removed_state(existing)
else:
return
dict.__setitem__(self, state.key, state)
self._dict[state.key] = state
self._manage_incoming_state(state)
def add(self, state):
if state.key in self:
if dict.__getitem__(self, state.key) is not state:
raise AssertionError("A conflicting state is already "
"present in the identity map for key %r"
% (state.key, ))
else:
dict.__setitem__(self, state.key, state)
self._manage_incoming_state(state)
def remove_key(self, key):
state = dict.__getitem__(self, key)
self.remove(state)
def remove(self, state):
if dict.pop(self, state.key) is not state:
raise AssertionError("State %s is not present in this identity map" % state)
self._manage_removed_state(state)
def discard(self, state):
if self.contains_state(state):
dict.__delitem__(self, state.key)
self._manage_removed_state(state)
key = state.key
# inline of self.__contains__
if key in self._dict:
try:
existing_state = self._dict[key]
if existing_state is not state:
o = existing_state.obj()
if o is not None:
raise sa_exc.InvalidRequestError(
"Can't attach instance "
"%s; another instance with key %s is already "
"present in this session." % (
orm_util.state_str(state), state.key))
else:
return False
except KeyError:
pass
self._dict[key] = state
self._manage_incoming_state(state)
return True
def _add_unpresent(self, state, key):
# inlined form of add() called by loading.py
self._dict[key] = state
state._instance_dict = self._wr
def get(self, key, default=None):
state = dict.get(self, key, default)
if state is default:
if key not in self._dict:
return default
state = self._dict[key]
o = state.obj()
if o is None:
o = state._is_really_none()
if o is None:
return default
return o
# Py2K
def items(self):
return list(self.iteritems())
def iteritems(self):
for state in dict.itervalues(self):
# end Py2K
# Py3K
#def items(self):
# for state in dict.values(self):
def items(self):
values = self.all_states()
result = []
for state in values:
value = state.obj()
if value is not None:
yield state.key, value
result.append((state.key, value))
return result
# Py2K
def values(self):
return list(self.itervalues())
values = self.all_states()
result = []
for state in values:
value = state.obj()
if value is not None:
result.append(value)
def itervalues(self):
for state in dict.itervalues(self):
# end Py2K
# Py3K
#def values(self):
# for state in dict.values(self):
instance = state.obj()
if instance is not None:
yield instance
return result
def __iter__(self):
return iter(self.keys())
if util.py2k:
def iteritems(self):
return iter(self.items())
def itervalues(self):
return iter(self.values())
def all_states(self):
# Py3K
# return list(dict.values(self))
# Py2K
return dict.values(self)
# end Py2K
if util.py2k:
return self._dict.values()
else:
return list(self._dict.values())
def _fast_discard(self, state):
self._dict.pop(state.key, None)
def discard(self, state):
st = self._dict.pop(state.key, None)
if st:
assert st is state
self._manage_removed_state(state)
def safe_discard(self, state):
if state.key in self._dict:
st = self._dict[state.key]
if st is state:
self._dict.pop(state.key, None)
self._manage_removed_state(state)
def prune(self):
return 0
class StrongInstanceDict(IdentityMap):
"""A 'strong-referencing' version of the identity map.
.. deprecated 1.1::
The strong
reference identity map is legacy. See the
recipe at :ref:`session_referencing_behavior` for
an event-based approach to maintaining strong identity
references.
"""
if util.py2k:
def itervalues(self):
return self._dict.itervalues()
def iteritems(self):
return self._dict.iteritems()
def __iter__(self):
return iter(self.dict_)
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def get(self, key, default=None):
return self._dict.get(key, default)
def values(self):
return self._dict.values()
def items(self):
return self._dict.items()
def all_states(self):
return [attributes.instance_state(o) for o in self.itervalues()]
return [attributes.instance_state(o) for o in self.values()]
def contains_state(self, state):
return state.key in self and attributes.instance_state(self[state.key]) is state
return (
state.key in self and
attributes.instance_state(self[state.key]) is state)
def replace(self, state):
if dict.__contains__(self, state.key):
existing = dict.__getitem__(self, state.key)
if state.key in self._dict:
existing = self._dict[state.key]
existing = attributes.instance_state(existing)
if existing is not state:
self._manage_removed_state(existing)
else:
return
dict.__setitem__(self, state.key, state.obj())
self._dict[state.key] = state.obj()
self._manage_incoming_state(state)
def add(self, state):
if state.key in self:
if attributes.instance_state(dict.__getitem__(self, state.key)) is not state:
raise AssertionError("A conflicting state is already present in the identity map for key %r" % (state.key, ))
if attributes.instance_state(self._dict[state.key]) is not state:
raise sa_exc.InvalidRequestError(
"Can't attach instance "
"%s; another instance with key %s is already "
"present in this session." % (
orm_util.state_str(state), state.key))
return False
else:
dict.__setitem__(self, state.key, state.obj())
self._dict[state.key] = state.obj()
self._manage_incoming_state(state)
def remove(self, state):
if attributes.instance_state(dict.pop(self, state.key)) is not state:
raise AssertionError("State %s is not present in this identity map" % state)
self._manage_removed_state(state)
return True
def _add_unpresent(self, state, key):
# inlined form of add() called by loading.py
self._dict[key] = state.obj()
state._instance_dict = self._wr
def _fast_discard(self, state):
self._dict.pop(state.key, None)
def discard(self, state):
if self.contains_state(state):
dict.__delitem__(self, state.key)
obj = self._dict.pop(state.key, None)
if obj is not None:
self._manage_removed_state(state)
def remove_key(self, key):
state = attributes.instance_state(dict.__getitem__(self, key))
self.remove(state)
st = attributes.instance_state(obj)
assert st is state
def safe_discard(self, state):
if state.key in self._dict:
obj = self._dict[state.key]
st = attributes.instance_state(obj)
if st is state:
self._dict.pop(state.key, None)
self._manage_removed_state(state)
def prune(self):
"""prune unreferenced, non-dirty states."""
ref_count = len(self)
dirty = [s.obj() for s in self.all_states() if s.modified]
@ -244,8 +315,7 @@ class StrongInstanceDict(IdentityMap):
keepers = weakref.WeakValueDictionary()
keepers.update(self)
dict.clear(self)
dict.update(self, keepers)
self._dict.clear()
self._dict.update(keepers)
self.modified = bool(dirty)
return ref_count - len(self)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,96 +1,120 @@
# scoping.py
# Copyright (C) the SQLAlchemy authors and contributors
# orm/scoping.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sqlalchemy.exceptions as sa_exc
from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \
to_list, get_cls_kwargs, deprecated
from sqlalchemy.orm import (
EXT_CONTINUE, MapperExtension, class_mapper, object_session
)
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm.session import Session
from .. import exc as sa_exc
from ..util import ScopedRegistry, ThreadLocalRegistry, warn
from . import class_mapper, exc as orm_exc
from .session import Session
__all__ = ['ScopedSession']
__all__ = ['scoped_session']
class ScopedSession(object):
"""Provides thread-local management of Sessions.
class scoped_session(object):
"""Provides scoped management of :class:`.Session` objects.
Usage::
Session = scoped_session(sessionmaker(autoflush=True))
... use session normally.
See :ref:`unitofwork_contextual` for a tutorial.
"""
session_factory = None
"""The `session_factory` provided to `__init__` is stored in this
attribute and may be accessed at a later time. This can be useful when
a new non-scoped :class:`.Session` or :class:`.Connection` to the
database is needed."""
def __init__(self, session_factory, scopefunc=None):
"""Construct a new :class:`.scoped_session`.
:param session_factory: a factory to create new :class:`.Session`
instances. This is usually, but not necessarily, an instance
of :class:`.sessionmaker`.
:param scopefunc: optional function which defines
the current scope. If not passed, the :class:`.scoped_session`
object assumes "thread-local" scope, and will use
a Python ``threading.local()`` in order to maintain the current
:class:`.Session`. If passed, the function should return
a hashable token; this token will be used as the key in a
dictionary in order to store and retrieve the current
:class:`.Session`.
"""
self.session_factory = session_factory
if scopefunc:
self.registry = ScopedRegistry(session_factory, scopefunc)
else:
self.registry = ThreadLocalRegistry(session_factory)
self.extension = _ScopedExt(self)
def __call__(self, **kwargs):
if kwargs:
scope = kwargs.pop('scope', False)
def __call__(self, **kw):
r"""Return the current :class:`.Session`, creating it
using the :attr:`.scoped_session.session_factory` if not present.
:param \**kw: Keyword arguments will be passed to the
:attr:`.scoped_session.session_factory` callable, if an existing
:class:`.Session` is not present. If the :class:`.Session` is present
and keyword arguments have been passed,
:exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
"""
if kw:
scope = kw.pop('scope', False)
if scope is not None:
if self.registry.has():
raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
raise sa_exc.InvalidRequestError(
"Scoped session is already present; "
"no new arguments may be specified.")
else:
sess = self.session_factory(**kwargs)
sess = self.session_factory(**kw)
self.registry.set(sess)
return sess
else:
return self.session_factory(**kwargs)
return self.session_factory(**kw)
else:
return self.registry()
def remove(self):
"""Dispose of the current contextual session."""
"""Dispose of the current :class:`.Session`, if present.
This will first call :meth:`.Session.close` method
on the current :class:`.Session`, which releases any existing
transactional/connection resources still being held; transactions
specifically are rolled back. The :class:`.Session` is then
discarded. Upon next usage within the same scope,
the :class:`.scoped_session` will produce a new
:class:`.Session` object.
"""
if self.registry.has():
self.registry().close()
self.registry.clear()
@deprecated("Session.mapper is deprecated. "
"Please see http://www.sqlalchemy.org/trac/wiki/UsageRecipes/SessionAwareMapper "
"for information on how to replicate its behavior.")
def mapper(self, *args, **kwargs):
"""return a mapper() function which associates this ScopedSession with the Mapper.
def configure(self, **kwargs):
"""reconfigure the :class:`.sessionmaker` used by this
:class:`.scoped_session`.
DEPRECATED.
See :meth:`.sessionmaker.configure`.
"""
from sqlalchemy.orm import mapper
extension_args = dict((arg, kwargs.pop(arg))
for arg in get_cls_kwargs(_ScopedExt)
if arg in kwargs)
kwargs['extension'] = extension = to_list(kwargs.get('extension', []))
if extension_args:
extension.append(self.extension.configure(**extension_args))
else:
extension.append(self.extension)
return mapper(*args, **kwargs)
def configure(self, **kwargs):
"""reconfigure the sessionmaker used by this ScopedSession."""
if self.registry.has():
warn('At least one scoped session is already present. '
' configure() can not affect sessions that have '
'already been created.')
self.session_factory.configure(**kwargs)
def query_property(self, query_cls=None):
"""return a class property which produces a `Query` object against the
class when called.
"""return a class property which produces a :class:`.Query` object
against the class and the current :class:`.Session` when called.
e.g.::
Session = scoped_session(sessionmaker())
class MyClass(object):
@ -124,82 +148,37 @@ class ScopedSession(object):
return None
return query()
ScopedSession = scoped_session
"""Old name for backwards compatibility."""
def instrument(name):
def do(self, *args, **kwargs):
return getattr(self.registry(), name)(*args, **kwargs)
return do
for meth in Session.public_methods:
setattr(ScopedSession, meth, instrument(meth))
setattr(scoped_session, meth, instrument(meth))
def makeprop(name):
def set(self, attr):
setattr(self.registry(), name, attr)
def get(self):
return getattr(self.registry(), name)
return property(get, set)
for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'):
setattr(ScopedSession, prop, makeprop(prop))
for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map',
'is_active', 'autoflush', 'no_autoflush', 'info'):
setattr(scoped_session, prop, makeprop(prop))
def clslevel(name):
def do(cls, *args, **kwargs):
return getattr(Session, name)(*args, **kwargs)
return classmethod(do)
for prop in ('close_all', 'object_session', 'identity_key'):
setattr(ScopedSession, prop, clslevel(prop))
class _ScopedExt(MapperExtension):
def __init__(self, context, validate=False, save_on_init=True):
self.context = context
self.validate = validate
self.save_on_init = save_on_init
self.set_kwargs_on_init = True
def validating(self):
return _ScopedExt(self.context, validate=True)
def configure(self, **kwargs):
return _ScopedExt(self.context, **kwargs)
def instrument_class(self, mapper, class_):
class query(object):
def __getattr__(s, key):
return getattr(self.context.registry().query(class_), key)
def __call__(s):
return self.context.registry().query(class_)
def __get__(self, instance, cls):
return self
if not 'query' in class_.__dict__:
class_.query = query()
if self.set_kwargs_on_init and class_.__init__ is object.__init__:
class_.__init__ = self._default__init__(mapper)
def _default__init__(ext, mapper):
def __init__(self, **kwargs):
for key, value in kwargs.iteritems():
if ext.validate:
if not mapper.get_property(key, resolve_synonyms=False,
raiseerr=False):
raise sa_exc.ArgumentError(
"Invalid __init__ argument: '%s'" % key)
setattr(self, key, value)
return __init__
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
if self.save_on_init:
session = kwargs.pop('_sa_session', None)
if session is None:
session = self.context.registry()
session._save_without_cascade(instance)
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
sess = object_session(instance)
if sess:
sess.expunge(instance)
return EXT_CONTINUE
def dispose_class(self, mapper, class_):
if hasattr(class_, 'query'):
delattr(class_, 'query')
setattr(scoped_session, prop, clslevel(prop))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,98 +1,140 @@
# mapper/sync.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
# orm/sync.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""private module containing functions used for copying data
"""private module containing functions used for copying data
between instances based on join conditions.
"""
from sqlalchemy.orm import exc, util as mapperutil
from . import exc, util as orm_util, attributes
def populate(source, source_mapper, dest, dest_mapper,
synchronize_pairs, uowcommit, flag_cascaded_pks):
source_dict = source.dict
dest_dict = dest.dict
def populate(source, source_mapper, dest, dest_mapper,
synchronize_pairs, uowcommit, passive_updates):
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(source, l)
# inline of source_mapper._get_state_attr_by_column
prop = source_mapper._columntoproperty[l]
value = source.manager[prop.key].impl.get(source, source_dict,
attributes.PASSIVE_OFF)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
try:
dest_mapper._set_state_attr_by_column(dest, r, value)
# inline of dest_mapper._set_state_attr_by_column
prop = dest_mapper._columntoproperty[r]
dest.manager[prop.key].impl.set(dest, dest_dict, value, None)
except exc.UnmappedColumnError:
_raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
# techically the "r.primary_key" check isn't
# technically the "r.primary_key" check isn't
# needed here, but we check for this condition to limit
# how often this logic is invoked for memory/performance
# reasons, since we only need this info for a primary key
# destination.
if l.primary_key and r.primary_key and \
r.references(l) and passive_updates:
if flag_cascaded_pks and l.primary_key and \
r.primary_key and \
r.references(l):
uowcommit.attributes[("pk_cascaded", dest, r)] = True
def bulk_populate_inherit_keys(
source_dict, source_mapper, synchronize_pairs):
# a simplified version of populate() used by bulk insert mode
for l, r in synchronize_pairs:
try:
prop = source_mapper._columntoproperty[l]
value = source_dict[prop.key]
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, source_mapper, r)
try:
prop = source_mapper._columntoproperty[r]
source_dict[prop.key] = value
except exc.UnmappedColumnError:
_raise_col_to_prop(True, source_mapper, l, source_mapper, r)
def clear(dest, dest_mapper, synchronize_pairs):
for l, r in synchronize_pairs:
if r.primary_key:
if r.primary_key and \
dest_mapper._get_state_attr_by_column(
dest, dest.dict, r) not in orm_util._none_set:
raise AssertionError(
"Dependency rule tried to blank-out primary key "
"column '%s' on instance '%s'" %
(r, mapperutil.state_str(dest))
)
"Dependency rule tried to blank-out primary key "
"column '%s' on instance '%s'" %
(r, orm_util.state_str(dest))
)
try:
dest_mapper._set_state_attr_by_column(dest, r, None)
dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None)
except exc.UnmappedColumnError:
_raise_col_to_prop(True, None, l, dest_mapper, r)
def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
for l, r in synchronize_pairs:
try:
oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l)
value = source_mapper._get_state_attr_by_column(source, l)
oldvalue = source_mapper._get_committed_attr_by_column(
source.obj(), l)
value = source_mapper._get_state_attr_by_column(
source, source.dict, l, passive=attributes.PASSIVE_OFF)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
dest[r.key] = value
dest[old_prefix + r.key] = oldvalue
def populate_dict(source, source_mapper, dict_, synchronize_pairs):
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(source, l)
value = source_mapper._get_state_attr_by_column(
source, source.dict, l, passive=attributes.PASSIVE_OFF)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
dict_[r.key] = value
def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
"""return true if the source object has changes from an old to a
"""return true if the source object has changes from an old to a
new value on the given synchronize pairs
"""
for l, r in synchronize_pairs:
try:
prop = source_mapper._get_col_to_prop(l)
prop = source_mapper._columntoproperty[l]
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
history = uowcommit.get_attribute_history(source, prop.key, passive=True)
if len(history.deleted):
history = uowcommit.get_attribute_history(
source, prop.key, attributes.PASSIVE_NO_INITIALIZE)
if bool(history.deleted):
return True
else:
return False
def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column):
def _raise_col_to_prop(isdest, source_mapper, source_column,
dest_mapper, dest_column):
if isdest:
raise exc.UnmappedColumnError(
"Can't execute sync rule for destination column '%s'; "
"mapper '%s' does not map this column. Try using an explicit"
" `foreign_keys` collection which does not include this column "
"(or use a viewonly=True relation)." % (dest_column, source_mapper)
)
"Can't execute sync rule for "
"destination column '%s'; mapper '%s' does not map "
"this column. Try using an explicit `foreign_keys` "
"collection which does not include this column (or use "
"a viewonly=True relation)." % (dest_column, dest_mapper))
else:
raise exc.UnmappedColumnError(
"Can't execute sync rule for source column '%s'; mapper '%s' "
"does not map this column. Try using an explicit `foreign_keys`"
" collection which does not include destination column '%s' (or "
"use a viewonly=True relation)." %
(source_column, source_mapper, dest_column)
)
"Can't execute sync rule for "
"source column '%s'; mapper '%s' does not map this "
"column. Try using an explicit `foreign_keys` "
"collection which does not include destination column "
"'%s' (or use a viewonly=True relation)." %
(source_column, source_mapper, dest_column))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,12 @@
# processors.py
# sqlalchemy/processors.py
# Copyright (C) 2010-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""defines generic type conversion functions, as used in bind and result
"""defines generic type conversion functions, as used in bind and result
processors.
They all share one common characteristic: None is passed through unchanged.
@ -14,41 +16,47 @@ They all share one common characteristic: None is passed through unchanged.
import codecs
import re
import datetime
from . import util
def str_to_datetime_processor_factory(regexp, type_):
rmatch = regexp.match
# Even on python2.6 datetime.strptime is both slower than this code
# and it does not support microseconds.
has_named_groups = bool(regexp.groupindex)
def process(value):
if value is None:
return None
else:
return type_(*map(int, rmatch(value).groups(0)))
try:
m = rmatch(value)
except TypeError:
raise ValueError("Couldn't parse %s string '%r' "
"- value is not a string." %
(type_.__name__, value))
if m is None:
raise ValueError("Couldn't parse %s string: "
"'%s'" % (type_.__name__, value))
if has_named_groups:
groups = m.groupdict(0)
return type_(**dict(list(zip(
iter(groups.keys()),
list(map(int, iter(groups.values())))
))))
else:
return type_(*list(map(int, m.groups(0))))
return process
try:
from sqlalchemy.cprocessors import UnicodeResultProcessor, \
DecimalResultProcessor, \
to_float, to_str, int_to_boolean, \
str_to_datetime, str_to_time, \
str_to_date
def to_unicode_processor_factory(encoding, errors=None):
# this is cumbersome but it would be even more so on the C side
if errors is not None:
return UnicodeResultProcessor(encoding, errors).process
else:
return UnicodeResultProcessor(encoding).process
def to_decimal_processor_factory(target_class, scale=10):
# Note that the scale argument is not taken into account for integer
# values in the C implementation while it is in the Python one.
# For example, the Python implementation might return
# Decimal('5.00000') whereas the C implementation will
# return Decimal('5'). These are equivalent of course.
return DecimalResultProcessor(target_class, "%%.%df" % scale).process
def boolean_to_int(value):
if value is None:
return None
else:
return int(bool(value))
except ImportError:
def py_fallback():
def to_unicode_processor_factory(encoding, errors=None):
decoder = codecs.getdecoder(encoding)
@ -62,7 +70,22 @@ except ImportError:
return decoder(value, errors)[0]
return process
def to_decimal_processor_factory(target_class, scale=10):
def to_conditional_unicode_processor_factory(encoding, errors=None):
decoder = codecs.getdecoder(encoding)
def process(value):
if value is None:
return None
elif isinstance(value, util.text_type):
return value
else:
# decoder returns a tuple: (value, len). Simply dropping the
# len part is safe: it is done that way in the normal
# 'xx'.decode(encoding) code path.
return decoder(value, errors)[0]
return process
def to_decimal_processor_factory(target_class, scale):
fstring = "%%.%df" % scale
def process(value):
@ -88,14 +111,45 @@ except ImportError:
if value is None:
return None
else:
return value and True or False
return bool(value)
DATETIME_RE = re.compile("(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?")
TIME_RE = re.compile("(\d+):(\d+):(\d+)(?:\.(\d+))?")
DATE_RE = re.compile("(\d+)-(\d+)-(\d+)")
DATETIME_RE = re.compile(
r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?")
TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)")
str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE,
datetime.datetime)
str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time)
str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date)
return locals()
try:
from sqlalchemy.cprocessors import UnicodeResultProcessor, \
DecimalResultProcessor, \
to_float, to_str, int_to_boolean, \
str_to_datetime, str_to_time, \
str_to_date
def to_unicode_processor_factory(encoding, errors=None):
if errors is not None:
return UnicodeResultProcessor(encoding, errors).process
else:
return UnicodeResultProcessor(encoding).process
def to_conditional_unicode_processor_factory(encoding, errors=None):
if errors is not None:
return UnicodeResultProcessor(encoding, errors).conditional_process
else:
return UnicodeResultProcessor(encoding).conditional_process
def to_decimal_processor_factory(target_class, scale):
# Note that the scale argument is not taken into account for integer
# values in the C implementation while it is in the Python one.
# For example, the Python implementation might return
# Decimal('5.00000') whereas the C implementation will
# return Decimal('5'). These are equivalent of course.
return DecimalResultProcessor(target_class, "%%.%df" % scale).process
except ImportError:
globals().update(py_fallback())

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,11 @@
from sqlalchemy.sql.expression import (
# sql/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .expression import (
Alias,
ClauseElement,
ColumnCollection,
@ -11,9 +18,12 @@ from sqlalchemy.sql.expression import (
Select,
Selectable,
TableClause,
TableSample,
Update,
alias,
and_,
any_,
all_,
asc,
between,
bindparam,
@ -28,12 +38,16 @@ from sqlalchemy.sql.expression import (
except_all,
exists,
extract,
false,
False_,
func,
funcfilter,
insert,
intersect,
intersect_all,
join,
label,
lateral,
literal,
literal_column,
modifier,
@ -42,17 +56,43 @@ from sqlalchemy.sql.expression import (
or_,
outerjoin,
outparam,
over,
select,
subquery,
table,
tablesample,
text,
true,
True_,
tuple_,
type_coerce,
union,
union_all,
update,
)
within_group
)
from sqlalchemy.sql.visitors import ClauseVisitor
from .visitors import ClauseVisitor
__tmp = locals().keys()
__all__ = sorted([i for i in __tmp if not i.startswith('__')])
def __go(lcls):
global __all__
from .. import util as _sa_util
import inspect as _inspect
__all__ = sorted(name for name, obj in lcls.items()
if not (name.startswith('_') or _inspect.ismodule(obj)))
from .annotation import _prepare_annotations, Annotated
from .elements import AnnotatedColumnElement, ClauseList
from .selectable import AnnotatedFromClause
_prepare_annotations(ColumnElement, AnnotatedColumnElement)
_prepare_annotations(FromClause, AnnotatedFromClause)
_prepare_annotations(ClauseList, Annotated)
_sa_util.dependencies.resolve_all("sqlalchemy.sql")
from . import naming
__go(locals())

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,104 +1,813 @@
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.expression import (
ClauseList, Function, _literal_as_binds, text, _type_from_args
)
from sqlalchemy.sql import operators
from sqlalchemy.sql.visitors import VisitableType
# sql/functions.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""SQL function API, factories, and built-in functions.
"""
from . import sqltypes, schema
from .base import Executable, ColumnCollection
from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
literal_column, _type_from_args, ColumnElement, _clone,\
Over, BindParameter, FunctionFilter, Grouping, WithinGroup
from .selectable import FromClause, Select, Alias
from . import util as sqlutil
from . import operators
from .visitors import VisitableType
from .. import util
from . import annotation
_registry = util.defaultdict(dict)
def register_function(identifier, fn, package="_default"):
"""Associate a callable with a particular func. name.
This is normally called by _GenericMeta, but is also
available by itself so that a non-Function construct
can be associated with the :data:`.func` accessor (i.e.
CAST, EXTRACT).
"""
reg = _registry[package]
reg[identifier] = fn
class FunctionElement(Executable, ColumnElement, FromClause):
"""Base for SQL function-oriented constructs.
.. seealso::
:class:`.Function` - named SQL function.
:data:`.func` - namespace which produces registered or ad-hoc
:class:`.Function` instances.
:class:`.GenericFunction` - allows creation of registered function
types.
"""
packagenames = ()
def __init__(self, *clauses, **kwargs):
"""Construct a :class:`.FunctionElement`.
"""
args = [_literal_as_binds(c, self.name) for c in clauses]
self.clause_expr = ClauseList(
operator=operators.comma_op,
group_contents=True, *args).\
self_group()
def _execute_on_connection(self, connection, multiparams, params):
return connection._execute_function(self, multiparams, params)
@property
def columns(self):
"""The set of columns exported by this :class:`.FunctionElement`.
Function objects currently have no result column names built in;
this method returns a single-element column collection with
an anonymously named column.
An interim approach to providing named columns for a function
as a FROM clause is to build a :func:`.select` with the
desired columns::
from sqlalchemy.sql import column
stmt = select([column('x'), column('y')]).\
select_from(func.myfunction())
"""
return ColumnCollection(self.label(None))
@util.memoized_property
def clauses(self):
"""Return the underlying :class:`.ClauseList` which contains
the arguments for this :class:`.FunctionElement`.
"""
return self.clause_expr.element
def over(self, partition_by=None, order_by=None, rows=None, range_=None):
"""Produce an OVER clause against this function.
Used against aggregate or so-called "window" functions,
for database backends that support window functions.
The expression::
func.row_number().over(order_by='x')
is shorthand for::
from sqlalchemy import over
over(func.row_number(), order_by='x')
See :func:`~.expression.over` for a full description.
.. versionadded:: 0.7
"""
return Over(
self,
partition_by=partition_by,
order_by=order_by,
rows=rows,
range_=range_
)
def within_group(self, *order_by):
"""Produce a WITHIN GROUP (ORDER BY expr) clause against this function.
Used against so-called "ordered set aggregate" and "hypothetical
set aggregate" functions, including :class:`.percentile_cont`,
:class:`.rank`, :class:`.dense_rank`, etc.
See :func:`~.expression.within_group` for a full description.
.. versionadded:: 1.1
"""
return WithinGroup(self, *order_by)
def filter(self, *criterion):
"""Produce a FILTER clause against this function.
Used against aggregate and window functions,
for database backends that support the "FILTER" clause.
The expression::
func.count(1).filter(True)
is shorthand for::
from sqlalchemy import funcfilter
funcfilter(func.count(1), True)
.. versionadded:: 1.0.0
.. seealso::
:class:`.FunctionFilter`
:func:`.funcfilter`
"""
if not criterion:
return self
return FunctionFilter(self, *criterion)
@property
def _from_objects(self):
return self.clauses._from_objects
def get_children(self, **kwargs):
return self.clause_expr,
def _copy_internals(self, clone=_clone, **kw):
self.clause_expr = clone(self.clause_expr, **kw)
self._reset_exported()
FunctionElement.clauses._reset(self)
def within_group_type(self, within_group):
"""For types that define their return type as based on the criteria
within a WITHIN GROUP (ORDER BY) expression, called by the
:class:`.WithinGroup` construct.
Returns None by default, in which case the function's normal ``.type``
is used.
"""
return None
def alias(self, name=None, flat=False):
r"""Produce a :class:`.Alias` construct against this
:class:`.FunctionElement`.
This construct wraps the function in a named alias which
is suitable for the FROM clause, in the style accepted for example
by PostgreSQL.
e.g.::
from sqlalchemy.sql import column
stmt = select([column('data_view')]).\
select_from(SomeTable).\
select_from(func.unnest(SomeTable.data).alias('data_view')
)
Would produce:
.. sourcecode:: sql
SELECT data_view
FROM sometable, unnest(sometable.data) AS data_view
.. versionadded:: 0.9.8 The :meth:`.FunctionElement.alias` method
is now supported. Previously, this method's behavior was
undefined and did not behave consistently across versions.
"""
return Alias(self, name)
def select(self):
"""Produce a :func:`~.expression.select` construct
against this :class:`.FunctionElement`.
This is shorthand for::
s = select([function_element])
"""
s = Select([self])
if self._execution_options:
s = s.execution_options(**self._execution_options)
return s
def scalar(self):
"""Execute this :class:`.FunctionElement` against an embedded
'bind' and return a scalar value.
This first calls :meth:`~.FunctionElement.select` to
produce a SELECT construct.
Note that :class:`.FunctionElement` can be passed to
the :meth:`.Connectable.scalar` method of :class:`.Connection`
or :class:`.Engine`.
"""
return self.select().execute().scalar()
def execute(self):
"""Execute this :class:`.FunctionElement` against an embedded
'bind'.
This first calls :meth:`~.FunctionElement.select` to
produce a SELECT construct.
Note that :class:`.FunctionElement` can be passed to
the :meth:`.Connectable.execute` method of :class:`.Connection`
or :class:`.Engine`.
"""
return self.select().execute()
def _bind_param(self, operator, obj, type_=None):
return BindParameter(None, obj, _compared_to_operator=operator,
_compared_to_type=self.type, unique=True,
type_=type_)
def self_group(self, against=None):
# for the moment, we are parenthesizing all array-returning
# expressions against getitem. This may need to be made
# more portable if in the future we support other DBs
# besides postgresql.
if against is operators.getitem and \
isinstance(self.type, sqltypes.ARRAY):
return Grouping(self)
else:
return super(FunctionElement, self).self_group(against=against)
class _FunctionGenerator(object):
"""Generate :class:`.Function` objects based on getattr calls."""
def __init__(self, **opts):
self.__names = []
self.opts = opts
def __getattr__(self, name):
# passthru __ attributes; fixes pydoc
if name.startswith('__'):
try:
return self.__dict__[name]
except KeyError:
raise AttributeError(name)
elif name.endswith('_'):
name = name[0:-1]
f = _FunctionGenerator(**self.opts)
f.__names = list(self.__names) + [name]
return f
def __call__(self, *c, **kwargs):
o = self.opts.copy()
o.update(kwargs)
tokens = len(self.__names)
if tokens == 2:
package, fname = self.__names
elif tokens == 1:
package, fname = "_default", self.__names[0]
else:
package = None
if package is not None:
func = _registry[package].get(fname)
if func is not None:
return func(*c, **o)
return Function(self.__names[-1],
packagenames=self.__names[0:-1], *c, **o)
func = _FunctionGenerator()
"""Generate SQL function expressions.
:data:`.func` is a special object instance which generates SQL
functions based on name-based attributes, e.g.::
>>> print(func.count(1))
count(:param_1)
The element is a column-oriented SQL element like any other, and is
used in that way::
>>> print(select([func.count(table.c.id)]))
SELECT count(sometable.id) FROM sometable
Any name can be given to :data:`.func`. If the function name is unknown to
SQLAlchemy, it will be rendered exactly as is. For common SQL functions
which SQLAlchemy is aware of, the name may be interpreted as a *generic
function* which will be compiled appropriately to the target database::
>>> print(func.current_timestamp())
CURRENT_TIMESTAMP
To call functions which are present in dot-separated packages,
specify them in the same manner::
>>> print(func.stats.yield_curve(5, 10))
stats.yield_curve(:yield_curve_1, :yield_curve_2)
SQLAlchemy can be made aware of the return type of functions to enable
type-specific lexical and result-based behavior. For example, to ensure
that a string-based function returns a Unicode value and is similarly
treated as a string in expressions, specify
:class:`~sqlalchemy.types.Unicode` as the type:
>>> print(func.my_string(u'hi', type_=Unicode) + ' ' +
... func.my_string(u'there', type_=Unicode))
my_string(:my_string_1) || :my_string_2 || my_string(:my_string_3)
The object returned by a :data:`.func` call is usually an instance of
:class:`.Function`.
This object meets the "column" interface, including comparison and labeling
functions. The object can also be passed the :meth:`~.Connectable.execute`
method of a :class:`.Connection` or :class:`.Engine`, where it will be
wrapped inside of a SELECT statement first::
print(connection.execute(func.current_timestamp()).scalar())
In a few exception cases, the :data:`.func` accessor
will redirect a name to a built-in expression such as :func:`.cast`
or :func:`.extract`, as these names have well-known meaning
but are not exactly the same as "functions" from a SQLAlchemy
perspective.
.. versionadded:: 0.8 :data:`.func` can return non-function expression
constructs for common quasi-functional names like :func:`.cast`
and :func:`.extract`.
Functions which are interpreted as "generic" functions know how to
calculate their return type automatically. For a listing of known generic
functions, see :ref:`generic_functions`.
.. note::
The :data:`.func` construct has only limited support for calling
standalone "stored procedures", especially those with special
parameterization concerns.
See the section :ref:`stored_procedures` for details on how to use
the DBAPI-level ``callproc()`` method for fully traditional stored
procedures.
"""
modifier = _FunctionGenerator(group=False)
class Function(FunctionElement):
"""Describe a named SQL function.
See the superclass :class:`.FunctionElement` for a description
of public methods.
.. seealso::
:data:`.func` - namespace which produces registered or ad-hoc
:class:`.Function` instances.
:class:`.GenericFunction` - allows creation of registered function
types.
"""
__visit_name__ = 'function'
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
The :data:`.func` construct is normally used to construct
new :class:`.Function` instances.
"""
self.packagenames = kw.pop('packagenames', None) or []
self.name = name
self._bind = kw.get('bind', None)
self.type = sqltypes.to_instance(kw.get('type_', None))
FunctionElement.__init__(self, *clauses, **kw)
def _bind_param(self, operator, obj, type_=None):
return BindParameter(self.name, obj,
_compared_to_operator=operator,
_compared_to_type=self.type,
type_=type_,
unique=True)
class _GenericMeta(VisitableType):
def __call__(self, *args, **kwargs):
args = [_literal_as_binds(c) for c in args]
return type.__call__(self, *args, **kwargs)
def __init__(cls, clsname, bases, clsdict):
if annotation.Annotated not in cls.__mro__:
cls.name = name = clsdict.get('name', clsname)
cls.identifier = identifier = clsdict.get('identifier', name)
package = clsdict.pop('package', '_default')
# legacy
if '__return_type__' in clsdict:
cls.type = clsdict['__return_type__']
register_function(identifier, cls, package)
super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
class GenericFunction(Function):
__metaclass__ = _GenericMeta
def __init__(self, type_=None, args=(), **kwargs):
class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
"""Define a 'generic' function.
A generic function is a pre-established :class:`.Function`
class that is instantiated automatically when called
by name from the :data:`.func` attribute. Note that
calling any name from :data:`.func` has the effect that
a new :class:`.Function` instance is created automatically,
given that name. The primary use case for defining
a :class:`.GenericFunction` class is so that a function
of a particular name may be given a fixed return type.
It can also include custom argument parsing schemes as well
as additional methods.
Subclasses of :class:`.GenericFunction` are automatically
registered under the name of the class. For
example, a user-defined function ``as_utc()`` would
be available immediately::
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy.types import DateTime
class as_utc(GenericFunction):
type = DateTime
print select([func.as_utc()])
User-defined generic functions can be organized into
packages by specifying the "package" attribute when defining
:class:`.GenericFunction`. Third party libraries
containing many functions may want to use this in order
to avoid name conflicts with other systems. For example,
if our ``as_utc()`` function were part of a package
"time"::
class as_utc(GenericFunction):
type = DateTime
package = "time"
The above function would be available from :data:`.func`
using the package name ``time``::
print select([func.time.as_utc()])
A final option is to allow the function to be accessed
from one name in :data:`.func` but to render as a different name.
The ``identifier`` attribute will override the name used to
access the function as loaded from :data:`.func`, but will retain
the usage of ``name`` as the rendered name::
class GeoBuffer(GenericFunction):
type = Geometry
package = "geo"
name = "ST_Buffer"
identifier = "buffer"
The above function will render as follows::
>>> print func.geo.buffer()
ST_Buffer()
.. versionadded:: 0.8 :class:`.GenericFunction` now supports
automatic registration of new functions as well as package
and custom naming support.
.. versionchanged:: 0.8 The attribute name ``type`` is used
to specify the function's return type at the class level.
Previously, the name ``__return_type__`` was used. This
name is still recognized for backwards-compatibility.
"""
coerce_arguments = True
def __init__(self, *args, **kwargs):
parsed_args = kwargs.pop('_parsed_args', None)
if parsed_args is None:
parsed_args = [_literal_as_binds(c, self.name) for c in args]
self.packagenames = []
self.name = self.__class__.__name__
self._bind = kwargs.get('bind', None)
self.clause_expr = ClauseList(
operator=operators.comma_op,
group_contents=True, *args).self_group()
operator=operators.comma_op,
group_contents=True, *parsed_args).self_group()
self.type = sqltypes.to_instance(
type_ or getattr(self, '__return_type__', None))
kwargs.pop("type_", None) or getattr(self, 'type', None))
register_function("cast", Cast)
register_function("extract", Extract)
class next_value(GenericFunction):
"""Represent the 'next value', given a :class:`.Sequence`
as its single argument.
Compiles into the appropriate function on each backend,
or will raise NotImplementedError if used on a backend
that does not provide support for sequences.
"""
type = sqltypes.Integer()
name = "next_value"
def __init__(self, seq, **kw):
assert isinstance(seq, schema.Sequence), \
"next_value() accepts a Sequence object as input."
self._bind = kw.get('bind', None)
self.sequence = seq
@property
def _from_objects(self):
return []
class AnsiFunction(GenericFunction):
def __init__(self, **kwargs):
GenericFunction.__init__(self, **kwargs)
class ReturnTypeFromArgs(GenericFunction):
"""Define a function whose return type is the same as its arguments."""
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c, self.name) for c in args]
kwargs.setdefault('type_', _type_from_args(args))
GenericFunction.__init__(self, args=args, **kwargs)
kwargs['_parsed_args'] = args
super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
class coalesce(ReturnTypeFromArgs):
pass
class max(ReturnTypeFromArgs):
pass
class min(ReturnTypeFromArgs):
pass
class sum(ReturnTypeFromArgs):
pass
class now(GenericFunction):
__return_type__ = sqltypes.DateTime
type = sqltypes.DateTime
class concat(GenericFunction):
__return_type__ = sqltypes.String
def __init__(self, *args, **kwargs):
GenericFunction.__init__(self, args=args, **kwargs)
type = sqltypes.String
class char_length(GenericFunction):
__return_type__ = sqltypes.Integer
type = sqltypes.Integer
def __init__(self, arg, **kwargs):
GenericFunction.__init__(self, args=[arg], **kwargs)
GenericFunction.__init__(self, arg, **kwargs)
class random(GenericFunction):
def __init__(self, *args, **kwargs):
kwargs.setdefault('type_', None)
GenericFunction.__init__(self, args=args, **kwargs)
pass
class count(GenericFunction):
"""The ANSI COUNT aggregate function. With no arguments, emits COUNT \*."""
r"""The ANSI COUNT aggregate function. With no arguments,
emits COUNT \*.
__return_type__ = sqltypes.Integer
"""
type = sqltypes.Integer
def __init__(self, expression=None, **kwargs):
if expression is None:
expression = text('*')
GenericFunction.__init__(self, args=(expression,), **kwargs)
expression = literal_column('*')
super(count, self).__init__(expression, **kwargs)
class current_date(AnsiFunction):
__return_type__ = sqltypes.Date
type = sqltypes.Date
class current_time(AnsiFunction):
__return_type__ = sqltypes.Time
type = sqltypes.Time
class current_timestamp(AnsiFunction):
__return_type__ = sqltypes.DateTime
type = sqltypes.DateTime
class current_user(AnsiFunction):
__return_type__ = sqltypes.String
type = sqltypes.String
class localtime(AnsiFunction):
__return_type__ = sqltypes.DateTime
type = sqltypes.DateTime
class localtimestamp(AnsiFunction):
__return_type__ = sqltypes.DateTime
type = sqltypes.DateTime
class session_user(AnsiFunction):
__return_type__ = sqltypes.String
type = sqltypes.String
class sysdate(AnsiFunction):
__return_type__ = sqltypes.DateTime
type = sqltypes.DateTime
class user(AnsiFunction):
__return_type__ = sqltypes.String
type = sqltypes.String
class array_agg(GenericFunction):
"""support for the ARRAY_AGG function.
The ``func.array_agg(expr)`` construct returns an expression of
type :class:`.types.ARRAY`.
e.g.::
stmt = select([func.array_agg(table.c.values)[2:5]])
.. versionadded:: 1.1
.. seealso::
:func:`.postgresql.array_agg` - PostgreSQL-specific version that
returns :class:`.postgresql.ARRAY`, which has PG-specific operators added.
"""
type = sqltypes.ARRAY
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c) for c in args]
kwargs.setdefault('type_', self.type(_type_from_args(args)))
kwargs['_parsed_args'] = args
super(array_agg, self).__init__(*args, **kwargs)
class OrderedSetAgg(GenericFunction):
"""Define a function where the return type is based on the sort
expression type as defined by the expression passed to the
:meth:`.FunctionElement.within_group` method."""
array_for_multi_clause = False
def within_group_type(self, within_group):
func_clauses = self.clause_expr.element
order_by = sqlutil.unwrap_order_by(within_group.order_by)
if self.array_for_multi_clause and len(func_clauses.clauses) > 1:
return sqltypes.ARRAY(order_by[0].type)
else:
return order_by[0].type
class mode(OrderedSetAgg):
"""implement the ``mode`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is the same as the sort expression.
.. versionadded:: 1.1
"""
class percentile_cont(OrderedSetAgg):
"""implement the ``percentile_cont`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is the same as the sort expression,
or if the arguments are an array, an :class:`.types.ARRAY` of the sort
expression's type.
.. versionadded:: 1.1
"""
array_for_multi_clause = True
class percentile_disc(OrderedSetAgg):
"""implement the ``percentile_disc`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is the same as the sort expression,
or if the arguments are an array, an :class:`.types.ARRAY` of the sort
expression's type.
.. versionadded:: 1.1
"""
array_for_multi_clause = True
class rank(GenericFunction):
"""Implement the ``rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is :class:`.Integer`.
.. versionadded:: 1.1
"""
type = sqltypes.Integer()
class dense_rank(GenericFunction):
"""Implement the ``dense_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is :class:`.Integer`.
.. versionadded:: 1.1
"""
type = sqltypes.Integer()
class percent_rank(GenericFunction):
"""Implement the ``percent_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is :class:`.Numeric`.
.. versionadded:: 1.1
"""
type = sqltypes.Numeric()
class cume_dist(GenericFunction):
"""Implement the ``cume_dist`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is :class:`.Numeric`.
.. versionadded:: 1.1
"""
type = sqltypes.Numeric()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,90 +1,137 @@
# sql/visitors.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Visitor/traversal interface and library functions.
SQLAlchemy schema and expression constructs rely on a Python-centric
version of the classic "visitor" pattern as the primary way in which
they apply functionality. The most common use of this pattern
is statement compilation, where individual expression classes match
up to rendering methods that produce a string result. Beyond this,
the visitor system is also used to inspect expressions for various
information and patterns, as well as for usage in
they apply functionality. The most common use of this pattern
is statement compilation, where individual expression classes match
up to rendering methods that produce a string result. Beyond this,
the visitor system is also used to inspect expressions for various
information and patterns, as well as for usage in
some kinds of expression transformation. Other kinds of transformation
use a non-visitor traversal system.
For many examples of how the visit system is used, see the
For many examples of how the visit system is used, see the
sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules.
For an introduction to clause adaption, see
http://techspot.zzzeek.org/?p=19 .
http://techspot.zzzeek.org/2008/01/23/expression-transformations/
"""
from collections import deque
import re
from sqlalchemy import util
from .. import util
import operator
from .. import exc
__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
'iterate_depthfirst', 'traverse_using', 'traverse',
'traverse_depthfirst',
'cloned_traverse', 'replacement_traverse']
__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
'iterate_depthfirst', 'traverse_using', 'traverse',
'cloned_traverse', 'replacement_traverse']
class VisitableType(type):
"""Metaclass which checks for a `__visit_name__` attribute and
applies `_compiler_dispatch` method to classes.
"""Metaclass which assigns a `_compiler_dispatch` method to classes
having a `__visit_name__` attribute.
The _compiler_dispatch attribute becomes an instance method which
looks approximately like the following::
def _compiler_dispatch (self, visitor, **kw):
'''Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.'''
visit_attr = 'visit_%s' % self.__visit_name__
return getattr(visitor, visit_attr)(self, **kw)
Classes having no __visit_name__ attribute will remain unaffected.
"""
def __init__(cls, clsname, bases, clsdict):
if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'):
super(VisitableType, cls).__init__(clsname, bases, clsdict)
return
# set up an optimized visit dispatch function
# for use by the compiler
visit_name = cls.__visit_name__
if isinstance(visit_name, str):
getter = operator.attrgetter("visit_%s" % visit_name)
def _compiler_dispatch(self, visitor, **kw):
return getter(visitor)(self, **kw)
else:
def _compiler_dispatch(self, visitor, **kw):
return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
cls._compiler_dispatch = _compiler_dispatch
if clsname != 'Visitable' and \
hasattr(cls, '__visit_name__'):
_generate_dispatch(cls)
super(VisitableType, cls).__init__(clsname, bases, clsdict)
class Visitable(object):
def _generate_dispatch(cls):
"""Return an optimized visit dispatch function for the cls
for use by the compiler.
"""
if '__visit_name__' in cls.__dict__:
visit_name = cls.__visit_name__
if isinstance(visit_name, str):
# There is an optimization opportunity here because the
# the string name of the class's __visit_name__ is known at
# this early stage (import time) so it can be pre-constructed.
getter = operator.attrgetter("visit_%s" % visit_name)
def _compiler_dispatch(self, visitor, **kw):
try:
meth = getter(visitor)
except AttributeError:
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
else:
# The optimization opportunity is lost for this case because the
# __visit_name__ is not yet a string. As a result, the visit
# string has to be recalculated with each compilation.
def _compiler_dispatch(self, visitor, **kw):
visit_attr = 'visit_%s' % self.__visit_name__
try:
meth = getattr(visitor, visit_attr)
except AttributeError:
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
_compiler_dispatch.__doc__ = \
"""Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.
"""
cls._compiler_dispatch = _compiler_dispatch
class Visitable(util.with_metaclass(VisitableType, object)):
"""Base class for visitable objects, applies the
``VisitableType`` metaclass.
"""
__metaclass__ = VisitableType
class ClauseVisitor(object):
"""Base class for visitor objects which can traverse using
"""Base class for visitor objects which can traverse using
the traverse() function.
"""
__traverse_options__ = {}
def traverse_single(self, obj):
def traverse_single(self, obj, **kw):
for v in self._visitor_iterator:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth:
return meth(obj)
def iterate(self, obj):
"""traverse the given expression structure, returning an iterator of all elements."""
return meth(obj, **kw)
def iterate(self, obj):
"""traverse the given expression structure, returning an iterator
of all elements.
"""
return iterate(obj, self.__traverse_options__)
def traverse(self, obj):
"""traverse and visit the given expression structure."""
return traverse(obj, self.__traverse_options__, self._visitor_dict)
@util.memoized_property
def _visitor_dict(self):
visitors = {}
@ -93,11 +140,11 @@ class ClauseVisitor(object):
if name.startswith('visit_'):
visitors[name[6:]] = getattr(self, name)
return visitors
@property
def _visitor_iterator(self):
"""iterate through this visitor and each 'chained' visitor."""
v = self
while v:
yield v
@ -105,41 +152,46 @@ class ClauseVisitor(object):
def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor.
the chained visitor will receive all visit events after this one.
"""
tail = list(self._visitor_iterator)[-1]
tail._next = visitor
return self
class CloningVisitor(ClauseVisitor):
"""Base class for visitor objects which can traverse using
"""Base class for visitor objects which can traverse using
the cloned_traverse() function.
"""
def copy_and_process(self, list_):
"""Apply cloned traversal to the given list of elements, and return the new list."""
"""Apply cloned traversal to the given list of elements, and return
the new list.
"""
return [self.traverse(x) for x in list_]
def traverse(self, obj):
"""traverse and visit the given expression structure."""
return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict)
return cloned_traverse(
obj, self.__traverse_options__, self._visitor_dict)
class ReplacingCloningVisitor(CloningVisitor):
"""Base class for visitor objects which can traverse using
"""Base class for visitor objects which can traverse using
the replacement_traverse() function.
"""
def replace(self, elem):
"""receive pre-copied elements during a cloning traversal.
If the method returns a new element, the element is used
instead of creating a simple copy of the element. Traversal
If the method returns a new element, the element is used
instead of creating a simple copy of the element. Traversal
will halt on the newly returned element if it is re-encountered.
"""
return None
@ -154,25 +206,39 @@ class ReplacingCloningVisitor(CloningVisitor):
return e
return replacement_traverse(obj, self.__traverse_options__, replace)
def iterate(obj, opts):
"""traverse the given expression structure, returning an iterator.
traversal is configured to be breadth-first.
"""
# fasttrack for atomic elements like columns
children = obj.get_children(**opts)
if not children:
return [obj]
traversal = deque()
stack = deque([obj])
while stack:
t = stack.popleft()
yield t
traversal.append(t)
for c in t.get_children(**opts):
stack.append(c)
return iter(traversal)
def iterate_depthfirst(obj, opts):
"""traverse the given expression structure, returning an iterator.
traversal is configured to be depth-first.
"""
# fasttrack for atomic elements like columns
children = obj.get_children(**opts)
if not children:
return [obj]
stack = deque([obj])
traversal = deque()
while stack:
@ -182,75 +248,81 @@ def iterate_depthfirst(obj, opts):
stack.append(c)
return iter(traversal)
def traverse_using(iterator, obj, visitors):
"""visit the given expression structure using the given iterator of objects."""
def traverse_using(iterator, obj, visitors):
"""visit the given expression structure using the given iterator of
objects.
"""
for target in iterator:
meth = visitors.get(target.__visit_name__, None)
if meth:
meth(target)
return obj
def traverse(obj, opts, visitors):
"""traverse and visit the given expression structure using the default iterator."""
def traverse(obj, opts, visitors):
"""traverse and visit the given expression structure using the default
iterator.
"""
return traverse_using(iterate(obj, opts), obj, visitors)
def traverse_depthfirst(obj, opts, visitors):
"""traverse and visit the given expression structure using the depth-first iterator."""
def traverse_depthfirst(obj, opts, visitors):
"""traverse and visit the given expression structure using the
depth-first iterator.
"""
return traverse_using(iterate_depthfirst(obj, opts), obj, visitors)
def cloned_traverse(obj, opts, visitors):
"""clone the given expression structure, allowing modifications by visitors."""
cloned = util.column_dict()
"""clone the given expression structure, allowing
modifications by visitors."""
def clone(element):
if element not in cloned:
cloned[element] = element._clone()
return cloned[element]
cloned = {}
stop_on = set(opts.get('stop_on', []))
obj = clone(obj)
stack = [obj]
def clone(elem):
if elem in stop_on:
return elem
else:
if id(elem) not in cloned:
cloned[id(elem)] = newelem = elem._clone()
newelem._copy_internals(clone=clone)
meth = visitors.get(newelem.__visit_name__, None)
if meth:
meth(newelem)
return cloned[id(elem)]
while stack:
t = stack.pop()
if t in cloned:
continue
t._copy_internals(clone=clone)
meth = visitors.get(t.__visit_name__, None)
if meth:
meth(t)
for c in t.get_children(**opts):
stack.append(c)
if obj is not None:
obj = clone(obj)
return obj
def replacement_traverse(obj, opts, replace):
"""clone the given expression structure, allowing element replacement by a given replacement function."""
cloned = util.column_dict()
stop_on = util.column_set(opts.get('stop_on', []))
"""clone the given expression structure, allowing element
replacement by a given replacement function."""
def clone(element):
newelem = replace(element)
if newelem is not None:
stop_on.add(newelem)
return newelem
cloned = {}
stop_on = set([id(x) for x in opts.get('stop_on', [])])
if element not in cloned:
cloned[element] = element._clone()
return cloned[element]
def clone(elem, **kw):
if id(elem) in stop_on or \
'no_replacement_traverse' in elem._annotations:
return elem
else:
newelem = replace(elem)
if newelem is not None:
stop_on.add(id(newelem))
return newelem
else:
if elem not in cloned:
cloned[elem] = newelem = elem._clone()
newelem._copy_internals(clone=clone, **kw)
return cloned[elem]
obj = clone(obj)
stack = [obj]
while stack:
t = stack.pop()
if t in stop_on:
continue
t._copy_internals(clone=clone)
for c in t.get_children(**opts):
stack.append(c)
if obj is not None:
obj = clone(obj, **opts)
return obj

File diff suppressed because it is too large Load Diff