Updated SqlAlchemy

This commit is contained in:
Christoffer Viken 2017-04-15 16:27:12 +00:00
parent 2c790e1fe1
commit e3267d4bda
59 changed files with 30236 additions and 26049 deletions

View File

@ -1,24 +1,23 @@
# __init__.py # sqlalchemy/__init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
import inspect
import sys
import sqlalchemy.exc as exceptions from .sql import (
sys.modules['sqlalchemy.exceptions'] = exceptions
from sqlalchemy.sql import (
alias, alias,
all_,
and_, and_,
any_,
asc, asc,
between, between,
bindparam, bindparam,
case, case,
cast, cast,
collate, collate,
column,
delete, delete,
desc, desc,
distinct, distinct,
@ -26,11 +25,14 @@ from sqlalchemy.sql import (
except_all, except_all,
exists, exists,
extract, extract,
false,
func, func,
funcfilter,
insert, insert,
intersect, intersect,
intersect_all, intersect_all,
join, join,
lateral,
literal, literal,
literal_column, literal_column,
modifier, modifier,
@ -39,16 +41,25 @@ from sqlalchemy.sql import (
or_, or_,
outerjoin, outerjoin,
outparam, outparam,
over,
select, select,
subquery, subquery,
table,
tablesample,
text, text,
true,
tuple_, tuple_,
type_coerce,
union, union,
union_all, union_all,
update, update,
within_group,
) )
from sqlalchemy.types import ( from .types import (
ARRAY,
BIGINT,
BINARY,
BLOB, BLOB,
BOOLEAN, BOOLEAN,
BigInteger, BigInteger,
@ -68,12 +79,14 @@ from sqlalchemy.types import (
INTEGER, INTEGER,
Integer, Integer,
Interval, Interval,
JSON,
LargeBinary, LargeBinary,
NCHAR, NCHAR,
NVARCHAR, NVARCHAR,
NUMERIC, NUMERIC,
Numeric, Numeric,
PickleType, PickleType,
REAL,
SMALLINT, SMALLINT,
SmallInteger, SmallInteger,
String, String,
@ -82,18 +95,19 @@ from sqlalchemy.types import (
TIMESTAMP, TIMESTAMP,
Text, Text,
Time, Time,
TypeDecorator,
Unicode, Unicode,
UnicodeText, UnicodeText,
VARBINARY,
VARCHAR, VARCHAR,
) )
from sqlalchemy.schema import ( from .schema import (
CheckConstraint, CheckConstraint,
Column, Column,
ColumnDefault, ColumnDefault,
Constraint, Constraint,
DDL,
DefaultClause, DefaultClause,
FetchedValue, FetchedValue,
ForeignKey, ForeignKey,
@ -106,14 +120,27 @@ from sqlalchemy.schema import (
Table, Table,
ThreadLocalMetaData, ThreadLocalMetaData,
UniqueConstraint, UniqueConstraint,
) DDL,
BLANK_SCHEMA
from sqlalchemy.engine import create_engine, engine_from_config )
__all__ = sorted(name for name, obj in locals().items() from .inspection import inspect
if not (name.startswith('_') or inspect.ismodule(obj))) from .engine import create_engine, engine_from_config
__version__ = '0.6beta3'
del inspect, sys __version__ = '1.1.9'
def __go(lcls):
global __all__
from . import events
from . import util as _sa_util
import inspect as _inspect
__all__ = sorted(name for name, obj in lcls.items()
if not (name.startswith('_') or _inspect.ismodule(obj)))
_sa_util.dependencies.resolve_all("sqlalchemy")
__go(locals())

View File

@ -1,6 +1,10 @@
# connectors/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
class Connector(object): class Connector(object):
pass pass

View File

@ -1,5 +1,12 @@
# connectors/mxodbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
""" """
Provide an SQLALchemy connector for the eGenix mxODBC commercial Provide a SQLALchemy connector for the eGenix mxODBC commercial
Python adapter for ODBC. This is not a free product, but eGenix Python adapter for ODBC. This is not a free product, but eGenix
provides SQLAlchemy with a license for use in continuous integration provides SQLAlchemy with a license for use in continuous integration
testing. testing.
@ -15,21 +22,19 @@ For more info on mxODBC, see http://www.egenix.com/
import sys import sys
import re import re
import warnings import warnings
from decimal import Decimal
from sqlalchemy.connectors import Connector from . import Connector
from sqlalchemy import types as sqltypes
import sqlalchemy.processors as processors
class MxODBCConnector(Connector): class MxODBCConnector(Connector):
driver='mxodbc' driver = 'mxodbc'
supports_sane_multi_rowcount = False supports_sane_multi_rowcount = False
supports_unicode_statements = False supports_unicode_statements = True
supports_unicode_binds = False supports_unicode_binds = True
supports_native_decimal = True supports_native_decimal = True
@classmethod @classmethod
def dbapi(cls): def dbapi(cls):
# this classmethod will normally be replaced by an instance # this classmethod will normally be replaced by an instance
@ -44,7 +49,7 @@ class MxODBCConnector(Connector):
elif platform == 'darwin': elif platform == 'darwin':
from mx.ODBC import iODBC as module from mx.ODBC import iODBC as module
else: else:
raise ImportError, "Unrecognized platform for mxODBC import" raise ImportError("Unrecognized platform for mxODBC import")
return module return module
@classmethod @classmethod
@ -64,21 +69,21 @@ class MxODBCConnector(Connector):
conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT
conn.errorhandler = self._error_handler() conn.errorhandler = self._error_handler()
return connect return connect
def _error_handler(self): def _error_handler(self):
""" Return a handler that adjusts mxODBC's raised Warnings to """ Return a handler that adjusts mxODBC's raised Warnings to
emit Python standard warnings. emit Python standard warnings.
""" """
from mx.ODBC.Error import Warning as MxOdbcWarning from mx.ODBC.Error import Warning as MxOdbcWarning
def error_handler(connection, cursor, errorclass, errorvalue):
def error_handler(connection, cursor, errorclass, errorvalue):
if issubclass(errorclass, MxOdbcWarning): if issubclass(errorclass, MxOdbcWarning):
errorclass.__bases__ = (Warning,) errorclass.__bases__ = (Warning,)
warnings.warn(message=str(errorvalue), warnings.warn(message=str(errorvalue),
category=errorclass, category=errorclass,
stacklevel=2) stacklevel=2)
else: else:
raise errorclass, errorvalue raise errorclass(errorvalue)
return error_handler return error_handler
def create_connect_args(self, url): def create_connect_args(self, url):
@ -94,7 +99,7 @@ class MxODBCConnector(Connector):
The arg 'errorhandler' is not used by SQLAlchemy and will The arg 'errorhandler' is not used by SQLAlchemy and will
not be populated. not be populated.
""" """
opts = url.translate_connect_args(username='user') opts = url.translate_connect_args(username='user')
opts.update(url.query) opts.update(url.query)
@ -103,9 +108,9 @@ class MxODBCConnector(Connector):
opts.pop('database', None) opts.pop('database', None)
return (args,), opts return (args,), opts
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
# eGenix recommends checking connection.closed here, # TODO: eGenix recommends checking connection.closed here
# but how can we get a handle on the current connection? # Does that detect dropped connections ?
if isinstance(e, self.dbapi.ProgrammingError): if isinstance(e, self.dbapi.ProgrammingError):
return "connection already closed" in str(e) return "connection already closed" in str(e)
elif isinstance(e, self.dbapi.Error): elif isinstance(e, self.dbapi.Error):
@ -114,10 +119,11 @@ class MxODBCConnector(Connector):
return False return False
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
# eGenix suggests using conn.dbms_version instead of what we're doing here # eGenix suggests using conn.dbms_version instead
# of what we're doing here
dbapi_con = connection.connection dbapi_con = connection.connection
version = [] version = []
r = re.compile('[.\-]') r = re.compile(r'[.\-]')
# 18 == pyodbc.SQL_DBMS_VER # 18 == pyodbc.SQL_DBMS_VER
for n in r.split(dbapi_con.getinfo(18)[1]): for n in r.split(dbapi_con.getinfo(18)[1]):
try: try:
@ -126,21 +132,19 @@ class MxODBCConnector(Connector):
version.append(n) version.append(n)
return tuple(version) return tuple(version)
def do_execute(self, cursor, statement, parameters, context=None): def _get_direct(self, context):
if context: if context:
native_odbc_execute = context.execution_options.\ native_odbc_execute = context.execution_options.\
get('native_odbc_execute', 'auto') get('native_odbc_execute', 'auto')
if native_odbc_execute is True: # default to direct=True in all cases, is more generally
# user specified native_odbc_execute=True # compatible especially with SQL Server
cursor.execute(statement, parameters) return False if native_odbc_execute is True else True
elif native_odbc_execute is False:
# user specified native_odbc_execute=False
cursor.executedirect(statement, parameters)
elif context.is_crud:
# statement is UPDATE, DELETE, INSERT
cursor.execute(statement, parameters)
else:
# all other statements
cursor.executedirect(statement, parameters)
else: else:
cursor.executedirect(statement, parameters) return True
def do_executemany(self, cursor, statement, parameters, context=None):
cursor.executemany(
statement, parameters, direct=self._get_direct(context))
def do_execute(self, cursor, statement, parameters, context=None):
cursor.execute(statement, parameters, direct=self._get_direct(context))

View File

@ -1,29 +1,51 @@
from sqlalchemy.connectors import Connector # connectors/pyodbc.py
from sqlalchemy.util import asbool # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import Connector
from .. import util
import sys import sys
import re import re
import urllib
import decimal
class PyODBCConnector(Connector): class PyODBCConnector(Connector):
driver='pyodbc' driver = 'pyodbc'
supports_sane_multi_rowcount = False supports_sane_multi_rowcount = False
# PyODBC unicode is broken on UCS-4 builds
supports_unicode = sys.maxunicode == 65535 if util.py2k:
supports_unicode_statements = supports_unicode # PyODBC unicode is broken on UCS-4 builds
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = supports_unicode
supports_native_decimal = True supports_native_decimal = True
default_paramstyle = 'named' default_paramstyle = 'named'
# for non-DSN connections, this should # for non-DSN connections, this *may* be used to
# hold the desired driver name # hold the desired driver name
pyodbc_driver_name = None pyodbc_driver_name = None
# will be set to True after initialize() # will be set to True after initialize()
# if the freetds.so is detected # if the freetds.so is detected
freetds = False freetds = False
# will be set to the string version of
# the FreeTDS driver if freetds is detected
freetds_driver_version = None
# will be set to True after initialize()
# if the libessqlsrv.so is detected
easysoft = False
def __init__(self, supports_unicode_binds=None, **kw):
super(PyODBCConnector, self).__init__(**kw)
self._user_supports_unicode_binds = supports_unicode_binds
@classmethod @classmethod
def dbapi(cls): def dbapi(cls):
return __import__('pyodbc') return __import__('pyodbc')
@ -31,29 +53,53 @@ class PyODBCConnector(Connector):
def create_connect_args(self, url): def create_connect_args(self, url):
opts = url.translate_connect_args(username='user') opts = url.translate_connect_args(username='user')
opts.update(url.query) opts.update(url.query)
keys = opts keys = opts
query = url.query query = url.query
connect_args = {} connect_args = {}
for param in ('ansi', 'unicode_results', 'autocommit'): for param in ('ansi', 'unicode_results', 'autocommit'):
if param in keys: if param in keys:
connect_args[param] = asbool(keys.pop(param)) connect_args[param] = util.asbool(keys.pop(param))
if 'odbc_connect' in keys: if 'odbc_connect' in keys:
connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))] connectors = [util.unquote_plus(keys.pop('odbc_connect'))]
else: else:
dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys) def check_quote(token):
if ";" in str(token):
token = "'%s'" % token
return token
keys = dict(
(k, check_quote(v)) for k, v in keys.items()
)
dsn_connection = 'dsn' in keys or \
('host' in keys and 'database' not in keys)
if dsn_connection: if dsn_connection:
connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))] connectors = ['dsn=%s' % (keys.pop('host', '') or
keys.pop('dsn', ''))]
else: else:
port = '' port = ''
if 'port' in keys and not 'port' in query: if 'port' in keys and 'port' not in query:
port = ',%d' % int(keys.pop('port')) port = ',%d' % int(keys.pop('port'))
connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name), connectors = []
'Server=%s%s' % (keys.pop('host', ''), port), driver = keys.pop('driver', self.pyodbc_driver_name)
'Database=%s' % keys.pop('database', '') ] if driver is None:
util.warn(
"No driver name specified; "
"this is expected by PyODBC when using "
"DSN-less connections")
else:
connectors.append("DRIVER={%s}" % driver)
connectors.extend(
[
'Server=%s%s' % (keys.pop('host', ''), port),
'Database=%s' % keys.pop('database', '')
])
user = keys.pop("user", None) user = keys.pop("user", None)
if user: if user:
@ -62,20 +108,22 @@ class PyODBCConnector(Connector):
else: else:
connectors.append("Trusted_Connection=Yes") connectors.append("Trusted_Connection=Yes")
# if set to 'Yes', the ODBC layer will try to automagically convert # if set to 'Yes', the ODBC layer will try to automagically
# textual data from your database encoding to your client encoding # convert textual data from your database encoding to your
# This should obviously be set to 'No' if you query a cp1253 encoded # client encoding. This should obviously be set to 'No' if
# database from a latin1 client... # you query a cp1253 encoded database from a latin1 client...
if 'odbc_autotranslate' in keys: if 'odbc_autotranslate' in keys:
connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) connectors.append("AutoTranslate=%s" %
keys.pop("odbc_autotranslate"))
connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()]) connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()])
return [[";".join (connectors)], connect_args]
return [[";".join(connectors)], connect_args]
def is_disconnect(self, e):
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.ProgrammingError): if isinstance(e, self.dbapi.ProgrammingError):
return "The cursor's connection has been closed." in str(e) or \ return "The cursor's connection has been closed." in str(e) or \
'Attempt to use a closed connection.' in str(e) 'Attempt to use a closed connection.' in str(e)
elif isinstance(e, self.dbapi.Error): elif isinstance(e, self.dbapi.Error):
return '[08S01]' in str(e) return '[08S01]' in str(e)
else: else:
@ -84,27 +132,62 @@ class PyODBCConnector(Connector):
def initialize(self, connection): def initialize(self, connection):
# determine FreeTDS first. can't issue SQL easily # determine FreeTDS first. can't issue SQL easily
# without getting unicode_statements/binds set up. # without getting unicode_statements/binds set up.
pyodbc = self.dbapi pyodbc = self.dbapi
dbapi_con = connection.connection dbapi_con = connection.connection
self.freetds = bool(re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME))) _sql_driver_name = dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)
self.freetds = bool(re.match(r".*libtdsodbc.*\.so", _sql_driver_name
))
self.easysoft = bool(re.match(r".*libessqlsrv.*\.so", _sql_driver_name
))
if self.freetds:
self.freetds_driver_version = dbapi_con.getinfo(
pyodbc.SQL_DRIVER_VER)
self.supports_unicode_statements = (
not util.py2k or
(not self.freetds and not self.easysoft)
)
if self._user_supports_unicode_binds is not None:
self.supports_unicode_binds = self._user_supports_unicode_binds
elif util.py2k:
self.supports_unicode_binds = (
not self.freetds or self.freetds_driver_version >= '0.91'
) and not self.easysoft
else:
self.supports_unicode_binds = True
# the "Py2K only" part here is theoretical.
# have not tried pyodbc + python3.1 yet.
# Py2K
self.supports_unicode_statements = not self.freetds
self.supports_unicode_binds = not self.freetds
# end Py2K
# run other initialization which asks for user name, etc. # run other initialization which asks for user name, etc.
super(PyODBCConnector, self).initialize(connection) super(PyODBCConnector, self).initialize(connection)
def _dbapi_version(self):
if not self.dbapi:
return ()
return self._parse_dbapi_version(self.dbapi.version)
def _parse_dbapi_version(self, vers):
m = re.match(
r'(?:py.*-)?([\d\.]+)(?:-(\w+))?',
vers
)
if not m:
return ()
vers = tuple([int(x) for x in m.group(1).split(".")])
if m.group(2):
vers += (m.group(2),)
return vers
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
# NOTE: this function is not reliable, particularly when
# freetds is in use. Implement database-specific server version
# queries.
dbapi_con = connection.connection dbapi_con = connection.connection
version = [] version = []
r = re.compile('[.\-]') r = re.compile(r'[.\-]')
for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
try: try:
version.append(int(n)) version.append(int(n))

View File

@ -1,20 +1,28 @@
# connectors/zxJDBC.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sys import sys
from sqlalchemy.connectors import Connector from . import Connector
class ZxJDBCConnector(Connector): class ZxJDBCConnector(Connector):
driver = 'zxjdbc' driver = 'zxjdbc'
supports_sane_rowcount = False supports_sane_rowcount = False
supports_sane_multi_rowcount = False supports_sane_multi_rowcount = False
supports_unicode_binds = True supports_unicode_binds = True
supports_unicode_statements = sys.version > '2.5.0+' supports_unicode_statements = sys.version > '2.5.0+'
description_encoding = None description_encoding = None
default_paramstyle = 'qmark' default_paramstyle = 'qmark'
jdbc_db_name = None jdbc_db_name = None
jdbc_driver_name = None jdbc_driver_name = None
@classmethod @classmethod
def dbapi(cls): def dbapi(cls):
from com.ziclix.python.sql import zxJDBC from com.ziclix.python.sql import zxJDBC
@ -23,20 +31,24 @@ class ZxJDBCConnector(Connector):
def _driver_kwargs(self): def _driver_kwargs(self):
"""Return kw arg dict to be sent to connect().""" """Return kw arg dict to be sent to connect()."""
return {} return {}
def _create_jdbc_url(self, url): def _create_jdbc_url(self, url):
"""Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`""" """Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`"""
return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host, return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host,
url.port is not None and ':%s' % url.port or '', url.port is not None
and ':%s' % url.port or '',
url.database) url.database)
def create_connect_args(self, url): def create_connect_args(self, url):
opts = self._driver_kwargs() opts = self._driver_kwargs()
opts.update(url.query) opts.update(url.query)
return [[self._create_jdbc_url(url), url.username, url.password, self.jdbc_driver_name], return [
opts] [self._create_jdbc_url(url),
url.username, url.password,
self.jdbc_driver_name],
opts]
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
if not isinstance(e, self.dbapi.ProgrammingError): if not isinstance(e, self.dbapi.ProgrammingError):
return False return False
e = str(e) e = str(e)

View File

@ -1,12 +1,56 @@
# dialects/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
__all__ = ( __all__ = (
# 'access', 'firebird',
# 'firebird', 'mssql',
# 'informix',
# 'maxdb',
# 'mssql',
'mysql', 'mysql',
'oracle', 'oracle',
'postgresql', 'postgresql',
'sqlite', 'sqlite',
# 'sybase', 'sybase',
) )
from .. import util
_translates = {'postgres': 'postgresql'}
def _auto_fn(name):
"""default dialect importer.
plugs into the :class:`.PluginLoader`
as a first-hit system.
"""
if "." in name:
dialect, driver = name.split(".")
else:
dialect = name
driver = "base"
if dialect in _translates:
translated = _translates[dialect]
util.warn_deprecated(
"The '%s' dialect name has been "
"renamed to '%s'" % (dialect, translated)
)
dialect = translated
try:
module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
except ImportError:
return None
module = getattr(module, dialect)
if hasattr(module, driver):
module = getattr(module, driver)
return lambda: module.dialect
else:
return None
registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
plugins = util.PluginLoader("sqlalchemy.plugins")

View File

@ -1,14 +1,36 @@
from sqlalchemy.dialects.postgresql import base, psycopg2, pg8000, pypostgresql, zxjdbc # postgresql/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import base, psycopg2, pg8000, pypostgresql, pygresql, \
zxjdbc, psycopg2cffi
base.dialect = psycopg2.dialect base.dialect = psycopg2.dialect
from sqlalchemy.dialects.postgresql.base import \ from .base import \
INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, INET, \ INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \
CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME,\ INET, CIDR, UUID, BIT, MACADDR, OID, DOUBLE_PRECISION, TIMESTAMP, TIME, \
DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect DATE, BYTEA, BOOLEAN, INTERVAL, ENUM, dialect, TSVECTOR, DropEnumType, \
CreateEnumType
from .hstore import HSTORE, hstore
from .json import JSON, JSONB
from .array import array, ARRAY, Any, All
from .ext import aggregate_order_by, ExcludeConstraint, array_agg
from .dml import insert, Insert
from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
TSTZRANGE
__all__ = ( __all__ = (
'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', 'FLOAT', 'REAL', 'INET', 'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC',
'CIDR', 'UUID', 'BIT', 'MACADDR', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR', 'OID',
'DATE', 'BYTEA', 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect' 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN',
'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array', 'HSTORE',
'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE',
'TSRANGE', 'TSTZRANGE', 'json', 'JSON', 'JSONB', 'Any', 'All',
'DropEnumType', 'CreateEnumType', 'ExcludeConstraint',
'aggregate_order_by', 'array_agg', 'insert', 'Insert'
) )

File diff suppressed because it is too large Load Diff

View File

@ -1,63 +1,130 @@
"""Support for the PostgreSQL database via the pg8000 driver. # postgresql/pg8000.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
Connecting """
---------- .. dialect:: postgresql+pg8000
:name: pg8000
:dbapi: pg8000
:connectstring: \
postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
:url: https://pythonhosted.org/pg8000/
URLs are of the form
`postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]`. .. _pg8000_unicode:
Unicode Unicode
------- -------
pg8000 requires that the postgresql client encoding be configured in the postgresql.conf file pg8000 will encode / decode string values between it and the server using the
in order to use encodings other than ascii. Set this value to the same value as PostgreSQL ``client_encoding`` parameter; by default this is the value in
the "encoding" parameter on create_engine(), usually "utf-8". the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
Typically, this can be changed to ``utf-8``, as a more useful default::
Interval #client_encoding = sql_ascii # actually, defaults to database
-------- # encoding
client_encoding = utf8
The ``client_encoding`` can be overridden for a session by executing the SQL:
SET CLIENT_ENCODING TO 'utf8';
SQLAlchemy will execute this SQL on all new connections based on the value
passed to :func:`.create_engine` using the ``client_encoding`` parameter::
engine = create_engine(
"postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
.. _pg8000_isolation_level:
pg8000 Transaction Isolation Level
-------------------------------------
The pg8000 dialect offers the same isolation level settings as that
of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
* ``READ COMMITTED``
* ``READ UNCOMMITTED``
* ``REPEATABLE READ``
* ``SERIALIZABLE``
* ``AUTOCOMMIT``
.. versionadded:: 0.9.5 support for AUTOCOMMIT isolation level when using
pg8000.
.. seealso::
:ref:`postgresql_isolation_level`
:ref:`psycopg2_isolation_level`
Passing data from/to the Interval type is not supported as of yet.
""" """
from ... import util, exc
import decimal import decimal
from ... import processors
from ... import types as sqltypes
from .base import (
PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext,
_DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES)
import re
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.engine import default
from sqlalchemy import util, exc
from sqlalchemy import processors
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, \
PGCompiler, PGIdentifierPreparer, PGExecutionContext
class _PGNumeric(sqltypes.Numeric): class _PGNumeric(sqltypes.Numeric):
def result_processor(self, dialect, coltype): def result_processor(self, dialect, coltype):
if self.asdecimal: if self.asdecimal:
if coltype in (700, 701): if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(decimal.Decimal) return processors.to_decimal_processor_factory(
elif coltype == 1700: decimal.Decimal, self._effective_decimal_return_scale)
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# pg8000 returns Decimal natively for 1700 # pg8000 returns Decimal natively for 1700
return None return None
else: else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
else: else:
if coltype in (700, 701): if coltype in _FLOAT_TYPES:
# pg8000 returns float natively for 701 # pg8000 returns float natively for 701
return None return None
elif coltype == 1700: elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float return processors.to_float
else: else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
class _PGNumericNoBind(_PGNumeric):
def bind_processor(self, dialect):
return None
class _PGJSON(JSON):
def result_processor(self, dialect, coltype):
if dialect._dbapi_version > (1, 10, 1):
return None # Has native JSON
else:
return super(_PGJSON, self).result_processor(dialect, coltype)
class PGExecutionContext_pg8000(PGExecutionContext): class PGExecutionContext_pg8000(PGExecutionContext):
pass pass
class PGCompiler_pg8000(PGCompiler): class PGCompiler_pg8000(PGCompiler):
def visit_mod(self, binary, **kw): def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right) return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
def post_process_text(self, text): def post_process_text(self, text):
if '%%' in text: if '%%' in text:
util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() " util.warn("The SQLAlchemy postgresql dialect "
"now automatically escapes '%' in text() "
"expressions to '%%'.") "expressions to '%%'.")
return text.replace('%', '%%') return text.replace('%', '%%')
@ -67,30 +134,52 @@ class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
value = value.replace(self.escape_quote, self.escape_to_quote) value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace('%', '%%') return value.replace('%', '%%')
class PGDialect_pg8000(PGDialect): class PGDialect_pg8000(PGDialect):
driver = 'pg8000' driver = 'pg8000'
supports_unicode_statements = True supports_unicode_statements = True
supports_unicode_binds = True supports_unicode_binds = True
default_paramstyle = 'format' default_paramstyle = 'format'
supports_sane_multi_rowcount = False supports_sane_multi_rowcount = True
execution_ctx_cls = PGExecutionContext_pg8000 execution_ctx_cls = PGExecutionContext_pg8000
statement_compiler = PGCompiler_pg8000 statement_compiler = PGCompiler_pg8000
preparer = PGIdentifierPreparer_pg8000 preparer = PGIdentifierPreparer_pg8000
description_encoding = 'use_encoding'
colspecs = util.update_copy( colspecs = util.update_copy(
PGDialect.colspecs, PGDialect.colspecs,
{ {
sqltypes.Numeric : _PGNumeric, sqltypes.Numeric: _PGNumericNoBind,
sqltypes.Float: _PGNumeric,
JSON: _PGJSON,
sqltypes.JSON: _PGJSON
} }
) )
def __init__(self, client_encoding=None, **kwargs):
PGDialect.__init__(self, **kwargs)
self.client_encoding = client_encoding
def initialize(self, connection):
self.supports_sane_multi_rowcount = self._dbapi_version >= (1, 9, 14)
super(PGDialect_pg8000, self).initialize(connection)
@util.memoized_property
def _dbapi_version(self):
if self.dbapi and hasattr(self.dbapi, '__version__'):
return tuple(
[
int(x) for x in re.findall(
r'(\d+)(?:[-\.]?|$)', self.dbapi.__version__)])
else:
return (99, 99, 99)
@classmethod @classmethod
def dbapi(cls): def dbapi(cls):
return __import__('pg8000').dbapi return __import__('pg8000')
def create_connect_args(self, url): def create_connect_args(self, url):
opts = url.translate_connect_args(username='user') opts = url.translate_connect_args(username='user')
@ -99,7 +188,78 @@ class PGDialect_pg8000(PGDialect):
opts.update(url.query) opts.update(url.query)
return ([], opts) return ([], opts)
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
return "connection is closed" in str(e) return "connection is closed" in str(e)
def set_isolation_level(self, connection, level):
level = level.replace('_', ' ')
# adjust for ConnectionFairy possibly being present
if hasattr(connection, 'connection'):
connection = connection.connection
if level == 'AUTOCOMMIT':
connection.autocommit = True
elif level in self._isolation_lookup:
connection.autocommit = False
cursor = connection.cursor()
cursor.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION "
"ISOLATION LEVEL %s" % level)
cursor.execute("COMMIT")
cursor.close()
else:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
"Valid isolation levels for %s are %s or AUTOCOMMIT" %
(level, self.name, ", ".join(self._isolation_lookup))
)
def set_client_encoding(self, connection, client_encoding):
# adjust for ConnectionFairy possibly being present
if hasattr(connection, 'connection'):
connection = connection.connection
cursor = connection.cursor()
cursor.execute("SET CLIENT_ENCODING TO '" + client_encoding + "'")
cursor.execute("COMMIT")
cursor.close()
def do_begin_twophase(self, connection, xid):
connection.connection.tpc_begin((0, xid, ''))
def do_prepare_twophase(self, connection, xid):
connection.connection.tpc_prepare()
def do_rollback_twophase(
self, connection, xid, is_prepared=True, recover=False):
connection.connection.tpc_rollback((0, xid, ''))
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False):
connection.connection.tpc_commit((0, xid, ''))
def do_recover_twophase(self, connection):
return [row[1] for row in connection.connection.tpc_recover()]
def on_connect(self):
fns = []
if self.client_encoding is not None:
def on_connect(conn):
self.set_client_encoding(conn, self.client_encoding)
fns.append(on_connect)
if self.isolation_level is not None:
def on_connect(conn):
self.set_isolation_level(conn, self.isolation_level)
fns.append(on_connect)
if len(fns) > 0:
def on_connect(conn):
for fn in fns:
fn(conn)
return on_connect
else:
return None
dialect = PGDialect_pg8000 dialect = PGDialect_pg8000

View File

@ -1,74 +1,335 @@
"""Support for the PostgreSQL database via the psycopg2 driver. # postgresql/psycopg2.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
Driver """
------ .. dialect:: postgresql+psycopg2
:name: psycopg2
:dbapi: psycopg2
:connectstring: postgresql+psycopg2://user:password@host:port/dbname\
[?key=value&key=value...]
:url: http://pypi.python.org/pypi/psycopg2/
The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ . psycopg2 Connect Arguments
The dialect has several behaviors which are specifically tailored towards compatibility -----------------------------------
with this module.
Note that psycopg1 is **not** supported. psycopg2-specific keyword arguments which are accepted by
:func:`.create_engine()` are:
Connecting * ``server_side_cursors``: Enable the usage of "server side cursors" for SQL
---------- statements which support this feature. What this essentially means from a
psycopg2 point of view is that the cursor is created using a name, e.g.
``connection.cursor('some name')``, which has the effect that result rows
are not immediately pre-fetched and buffered after statement execution, but
are instead left on the server and only retrieved as needed. SQLAlchemy's
:class:`~sqlalchemy.engine.ResultProxy` uses special row-buffering
behavior when this feature is enabled, such that groups of 100 rows at a
time are fetched over the wire to reduce conversational overhead.
Note that the :paramref:`.Connection.execution_options.stream_results`
execution option is a more targeted
way of enabling this mode on a per-execution basis.
* ``use_native_unicode``: Enable the usage of Psycopg2 "native unicode" mode
per connection. True by default.
URLs are of the form `postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]`. .. seealso::
psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are: :ref:`psycopg2_disable_native_unicode`
* ``isolation_level``: This option, available for all PostgreSQL dialects,
includes the ``AUTOCOMMIT`` isolation level when using the psycopg2
dialect.
.. seealso::
:ref:`psycopg2_isolation_level`
* ``client_encoding``: sets the client encoding in a libpq-agnostic way,
using psycopg2's ``set_client_encoding()`` method.
.. seealso::
:ref:`psycopg2_unicode`
Unix Domain Connections
------------------------
psycopg2 supports connecting via Unix domain connections. When the ``host``
portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2,
which specifies Unix-domain communication rather than TCP/IP communication::
create_engine("postgresql+psycopg2://user:password@/dbname")
By default, the socket file used is to connect to a Unix-domain socket
in ``/tmp``, or whatever socket directory was specified when PostgreSQL
was built. This value can be overridden by passing a pathname to psycopg2,
using ``host`` as an additional keyword argument::
create_engine("postgresql+psycopg2://user:password@/dbname?\
host=/var/lib/postgresql")
See also:
`PQconnectdbParams <http://www.postgresql.org/docs/9.1/static/\
libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_
.. _psycopg2_execution_options:
Per-Statement/Connection Execution Options
-------------------------------------------
The following DBAPI-specific options are respected when used with
:meth:`.Connection.execution_options`, :meth:`.Executable.execution_options`,
:meth:`.Query.execution_options`, in addition to those not specific to DBAPIs:
* ``isolation_level`` - Set the transaction isolation level for the lifespan of a
:class:`.Connection` (can only be set on a connection, not a statement
or query). See :ref:`psycopg2_isolation_level`.
* ``stream_results`` - Enable or disable usage of psycopg2 server side cursors -
this feature makes use of "named" cursors in combination with special
result handling methods so that result rows are not fully buffered.
If ``None`` or not set, the ``server_side_cursors`` option of the
:class:`.Engine` is used.
* ``max_row_buffer`` - when using ``stream_results``, an integer value that
specifies the maximum number of rows to buffer at a time. This is
interpreted by the :class:`.BufferedRowResultProxy`, and if omitted the
buffer will grow to ultimately store 1000 rows at a time.
.. versionadded:: 1.0.6
.. _psycopg2_unicode:
Unicode with Psycopg2
----------------------
By default, the psycopg2 driver uses the ``psycopg2.extensions.UNICODE``
extension, such that the DBAPI receives and returns all strings as Python
Unicode objects directly - SQLAlchemy passes these values through without
change. Psycopg2 here will encode/decode string values based on the
current "client encoding" setting; by default this is the value in
the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
Typically, this can be changed to ``utf8``, as a more useful default::
# postgresql.conf file
# client_encoding = sql_ascii # actually, defaults to database
# encoding
client_encoding = utf8
A second way to affect the client encoding is to set it within Psycopg2
locally. SQLAlchemy will call psycopg2's
:meth:`psycopg2:connection.set_client_encoding` method
on all new connections based on the value passed to
:func:`.create_engine` using the ``client_encoding`` parameter::
# set_client_encoding() setting;
# works for *all* PostgreSQL versions
engine = create_engine("postgresql://user:pass@host/dbname",
client_encoding='utf8')
This overrides the encoding specified in the PostgreSQL client configuration.
When using the parameter in this way, the psycopg2 driver emits
``SET client_encoding TO 'utf8'`` on the connection explicitly, and works
in all PostgreSQL versions.
Note that the ``client_encoding`` setting as passed to :func:`.create_engine`
is **not the same** as the more recently added ``client_encoding`` parameter
now supported by libpq directly. This is enabled when ``client_encoding``
is passed directly to ``psycopg2.connect()``, and from SQLAlchemy is passed
using the :paramref:`.create_engine.connect_args` parameter::
# libpq direct parameter setting;
# only works for PostgreSQL **9.1 and above**
engine = create_engine("postgresql://user:pass@host/dbname",
connect_args={'client_encoding': 'utf8'})
# using the query string is equivalent
engine = create_engine("postgresql://user:pass@host/dbname?client_encoding=utf8")
The above parameter was only added to libpq as of version 9.1 of PostgreSQL,
so using the previous method is better for cross-version support.
.. _psycopg2_disable_native_unicode:
Disabling Native Unicode
^^^^^^^^^^^^^^^^^^^^^^^^
SQLAlchemy can also be instructed to skip the usage of the psycopg2
``UNICODE`` extension and to instead utilize its own unicode encode/decode
services, which are normally reserved only for those DBAPIs that don't
fully support unicode directly. Passing ``use_native_unicode=False`` to
:func:`.create_engine` will disable usage of ``psycopg2.extensions.UNICODE``.
SQLAlchemy will instead encode data itself into Python bytestrings on the way
in and coerce from bytes on the way back,
using the value of the :func:`.create_engine` ``encoding`` parameter, which
defaults to ``utf-8``.
SQLAlchemy's own unicode encode/decode functionality is steadily becoming
obsolete as most DBAPIs now support unicode fully.
Bound Parameter Styles
----------------------
The default parameter style for the psycopg2 dialect is "pyformat", where
SQL is rendered using ``%(paramname)s`` style. This format has the limitation
that it does not accommodate the unusual case of parameter names that
actually contain percent or parenthesis symbols; as SQLAlchemy in many cases
generates bound parameter names based on the name of a column, the presence
of these characters in a column name can lead to problems.
There are two solutions to the issue of a :class:`.schema.Column` that contains
one of these characters in its name. One is to specify the
:paramref:`.schema.Column.key` for columns that have such names::
measurement = Table('measurement', metadata,
Column('Size (meters)', Integer, key='size_meters')
)
Above, an INSERT statement such as ``measurement.insert()`` will use
``size_meters`` as the parameter name, and a SQL expression such as
``measurement.c.size_meters > 10`` will derive the bound parameter name
from the ``size_meters`` key as well.
.. versionchanged:: 1.0.0 - SQL expressions will use :attr:`.Column.key`
as the source of naming when anonymous bound parameters are created
in SQL expressions; previously, this behavior only applied to
:meth:`.Table.insert` and :meth:`.Table.update` parameter names.
The other solution is to use a positional format; psycopg2 allows use of the
"format" paramstyle, which can be passed to
:paramref:`.create_engine.paramstyle`::
engine = create_engine(
'postgresql://scott:tiger@localhost:5432/test', paramstyle='format')
With the above engine, instead of a statement like::
INSERT INTO measurement ("Size (meters)") VALUES (%(Size (meters))s)
{'Size (meters)': 1}
we instead see::
INSERT INTO measurement ("Size (meters)") VALUES (%s)
(1, )
Where above, the dictionary style is converted into a tuple with positional
style.
* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
this feature. What this essentially means from a psycopg2 point of view is that the cursor is
created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
are not immediately pre-fetched and buffered after statement execution, but are instead left
on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows
at a time are fetched over the wire to reduce conversational overhead.
* *use_native_unicode* - Enable the usage of Psycopg2 "native unicode" mode per connection. True
by default.
* *isolation_level* - Sets the transaction isolation level for each transaction
within the engine. Valid isolation levels are `READ_COMMITTED`,
`READ_UNCOMMITTED`, `REPEATABLE_READ`, and `SERIALIZABLE`.
Transactions Transactions
------------ ------------
The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
.. _psycopg2_isolation_level:
Psycopg2 Transaction Isolation Level
-------------------------------------
As discussed in :ref:`postgresql_isolation_level`,
all PostgreSQL dialects support setting of transaction isolation level
both via the ``isolation_level`` parameter passed to :func:`.create_engine`,
as well as the ``isolation_level`` argument used by
:meth:`.Connection.execution_options`. When using the psycopg2 dialect, these
options make use of psycopg2's ``set_isolation_level()`` connection method,
rather than emitting a PostgreSQL directive; this is because psycopg2's
API-level setting is always emitted at the start of each transaction in any
case.
The psycopg2 dialect supports these constants for isolation level:
* ``READ COMMITTED``
* ``READ UNCOMMITTED``
* ``REPEATABLE READ``
* ``SERIALIZABLE``
* ``AUTOCOMMIT``
.. versionadded:: 0.8.2 support for AUTOCOMMIT isolation level when using
psycopg2.
.. seealso::
:ref:`postgresql_isolation_level`
:ref:`pg8000_isolation_level`
NOTICE logging NOTICE logging
--------------- ---------------
The psycopg2 dialect will log Postgresql NOTICE messages via the The psycopg2 dialect will log PostgreSQL NOTICE messages via the
``sqlalchemy.dialects.postgresql`` logger:: ``sqlalchemy.dialects.postgresql`` logger::
import logging import logging
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
.. _psycopg2_hstore::
Per-Statement Execution Options HSTORE type
------------------------------- ------------
The following per-statement execution options are respected: The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of
the HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension
by default when psycopg2 version 2.4 or greater is used, and
it is detected that the target database has the HSTORE type set up for use.
In other words, when the dialect makes the first
connection, a sequence like the following is performed:
* *stream_results* - Enable or disable usage of server side cursors for the SELECT-statement. 1. Request the available HSTORE oids using
If *None* or not set, the *server_side_cursors* option of the connection is used. If ``psycopg2.extras.HstoreAdapter.get_oids()``.
auto-commit is enabled, the option is ignored. If this function returns a list of HSTORE identifiers, we then determine
that the ``HSTORE`` extension is present.
This function is **skipped** if the version of psycopg2 installed is
less than version 2.4.
2. If the ``use_native_hstore`` flag is at its default of ``True``, and
we've detected that ``HSTORE`` oids are available, the
``psycopg2.extensions.register_hstore()`` extension is invoked for all
connections.
The ``register_hstore()`` extension has the effect of **all Python
dictionaries being accepted as parameters regardless of the type of target
column in SQL**. The dictionaries are converted by this extension into a
textual HSTORE expression. If this behavior is not desired, disable the
use of the hstore extension by setting ``use_native_hstore`` to ``False`` as
follows::
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
use_native_hstore=False)
The ``HSTORE`` type is **still supported** when the
``psycopg2.extensions.register_hstore()`` extension is not used. It merely
means that the coercion between Python dictionaries and the HSTORE
string format, on both the parameter side and the result side, will take
place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2``
which may be more performant.
""" """
from __future__ import absolute_import
import random
import re import re
import decimal
import logging import logging
from sqlalchemy import util from ... import util, exc
from sqlalchemy import processors import decimal
from sqlalchemy.engine import base, default from ... import processors
from sqlalchemy.sql import expression from ...engine import result as _result
from sqlalchemy.sql import operators as sql_operators from ...sql import expression
from sqlalchemy import types as sqltypes from ... import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, \ from .base import PGDialect, PGCompiler, \
PGIdentifierPreparer, PGExecutionContext, \ PGIdentifierPreparer, PGExecutionContext, \
ENUM, ARRAY ENUM, _DECIMAL_TYPES, _FLOAT_TYPES,\
_INT_TYPES, UUID
from .hstore import HSTORE
from .json import JSON, JSONB
try:
from uuid import UUID as _python_UUID
except ImportError:
_python_UUID = None
logger = logging.getLogger('sqlalchemy.dialects.postgresql') logger = logging.getLogger('sqlalchemy.dialects.postgresql')
@ -80,82 +341,113 @@ class _PGNumeric(sqltypes.Numeric):
def result_processor(self, dialect, coltype): def result_processor(self, dialect, coltype):
if self.asdecimal: if self.asdecimal:
if coltype in (700, 701): if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(decimal.Decimal) return processors.to_decimal_processor_factory(
elif coltype == 1700: decimal.Decimal,
self._effective_decimal_return_scale)
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# pg8000 returns Decimal natively for 1700 # pg8000 returns Decimal natively for 1700
return None return None
else: else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
else: else:
if coltype in (700, 701): if coltype in _FLOAT_TYPES:
# pg8000 returns float natively for 701 # pg8000 returns float natively for 701
return None return None
elif coltype == 1700: elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float return processors.to_float
else: else:
raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype)
class _PGEnum(ENUM): class _PGEnum(ENUM):
def __init__(self, *arg, **kw): def result_processor(self, dialect, coltype):
super(_PGEnum, self).__init__(*arg, **kw) if self.native_enum and util.py2k and self.convert_unicode is True:
if self.convert_unicode: # we can't easily use PG's extensions here because
self.convert_unicode = "force" # the OID is on the fly, and we need to give it a python
# function anyway - not really worth it.
self.convert_unicode = "force_nocheck"
return super(_PGEnum, self).result_processor(dialect, coltype)
class _PGArray(ARRAY):
def __init__(self, *arg, **kw):
super(_PGArray, self).__init__(*arg, **kw)
# FIXME: this check won't work for setups that
# have convert_unicode only on their create_engine().
if isinstance(self.item_type, sqltypes.String) and \
self.item_type.convert_unicode:
self.item_type.convert_unicode = "force"
# When we're handed literal SQL, ensure it's a SELECT-query. Since class _PGHStore(HSTORE):
# 8.3, combining cursors and "FOR UPDATE" has been fine. def bind_processor(self, dialect):
SERVER_SIDE_CURSOR_RE = re.compile( if dialect._has_native_hstore:
r'\s*SELECT', return None
re.I | re.UNICODE) else:
return super(_PGHStore, self).bind_processor(dialect)
def result_processor(self, dialect, coltype):
if dialect._has_native_hstore:
return None
else:
return super(_PGHStore, self).result_processor(dialect, coltype)
class _PGJSON(JSON):
def result_processor(self, dialect, coltype):
if dialect._has_native_json:
return None
else:
return super(_PGJSON, self).result_processor(dialect, coltype)
class _PGJSONB(JSONB):
def result_processor(self, dialect, coltype):
if dialect._has_native_jsonb:
return None
else:
return super(_PGJSONB, self).result_processor(dialect, coltype)
class _PGUUID(UUID):
def bind_processor(self, dialect):
if not self.as_uuid and dialect.use_native_uuid:
nonetype = type(None)
def process(value):
if value is not None:
value = _python_UUID(value)
return value
return process
def result_processor(self, dialect, coltype):
if not self.as_uuid and dialect.use_native_uuid:
def process(value):
if value is not None:
value = str(value)
return value
return process
_server_side_id = util.counter()
class PGExecutionContext_psycopg2(PGExecutionContext): class PGExecutionContext_psycopg2(PGExecutionContext):
def create_cursor(self): def create_server_side_cursor(self):
# TODO: coverage for server side cursors + select.for_update() # use server-side cursors:
# http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
if self.dialect.server_side_cursors: ident = "c_%s_%s" % (hex(id(self))[2:],
is_server_side = \ hex(_server_side_id())[2:])
self.execution_options.get('stream_results', True) and ( return self._dbapi_connection.cursor(ident)
(self.compiled and isinstance(self.compiled.statement, expression.Selectable) \
or \
(
(not self.compiled or
isinstance(self.compiled.statement, expression._TextClause))
and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
)
)
else:
is_server_side = self.execution_options.get('stream_results', False)
self.__is_server_side = is_server_side
if is_server_side:
# use server-side cursors:
# http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
return self._connection.connection.cursor(ident)
else:
return self._connection.connection.cursor()
def get_result_proxy(self): def get_result_proxy(self):
# TODO: ouch
if logger.isEnabledFor(logging.INFO): if logger.isEnabledFor(logging.INFO):
self._log_notices(self.cursor) self._log_notices(self.cursor)
if self.__is_server_side: if self._is_server_side:
return base.BufferedRowResultProxy(self) return _result.BufferedRowResultProxy(self)
else: else:
return base.ResultProxy(self) return _result.ResultProxy(self)
def _log_notices(self, cursor): def _log_notices(self, cursor):
for notice in cursor.connection.notices: for notice in cursor.connection.notices:
# NOTICE messages have a # NOTICE messages have a
# newline character at the end # newline character at the end
logger.info(notice.rstrip()) logger.info(notice.rstrip())
@ -163,9 +455,10 @@ class PGExecutionContext_psycopg2(PGExecutionContext):
class PGCompiler_psycopg2(PGCompiler): class PGCompiler_psycopg2(PGCompiler):
def visit_mod(self, binary, **kw): def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left) + " %% " + self.process(binary.right) return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
def post_process_text(self, text): def post_process_text(self, text):
return text.replace('%', '%%') return text.replace('%', '%%')
@ -175,47 +468,191 @@ class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
value = value.replace(self.escape_quote, self.escape_to_quote) value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace('%', '%%') return value.replace('%', '%%')
class PGDialect_psycopg2(PGDialect): class PGDialect_psycopg2(PGDialect):
driver = 'psycopg2' driver = 'psycopg2'
supports_unicode_statements = False if util.py2k:
supports_unicode_statements = False
supports_server_side_cursors = True
default_paramstyle = 'pyformat' default_paramstyle = 'pyformat'
# set to true based on psycopg2 version
supports_sane_multi_rowcount = False supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_psycopg2 execution_ctx_cls = PGExecutionContext_psycopg2
statement_compiler = PGCompiler_psycopg2 statement_compiler = PGCompiler_psycopg2
preparer = PGIdentifierPreparer_psycopg2 preparer = PGIdentifierPreparer_psycopg2
psycopg2_version = (0, 0)
FEATURE_VERSION_MAP = dict(
native_json=(2, 5),
native_jsonb=(2, 5, 4),
sane_multi_rowcount=(2, 0, 9),
array_oid=(2, 4, 3),
hstore_adapter=(2, 4)
)
_has_native_hstore = False
_has_native_json = False
_has_native_jsonb = False
engine_config_types = PGDialect.engine_config_types.union([
('use_native_unicode', util.asbool),
])
colspecs = util.update_copy( colspecs = util.update_copy(
PGDialect.colspecs, PGDialect.colspecs,
{ {
sqltypes.Numeric : _PGNumeric, sqltypes.Numeric: _PGNumeric,
ENUM : _PGEnum, # needs force_unicode ENUM: _PGEnum, # needs force_unicode
sqltypes.Enum : _PGEnum, # needs force_unicode sqltypes.Enum: _PGEnum, # needs force_unicode
ARRAY : _PGArray, # needs force_unicode HSTORE: _PGHStore,
JSON: _PGJSON,
sqltypes.JSON: _PGJSON,
JSONB: _PGJSONB,
UUID: _PGUUID
} }
) )
def __init__(self, server_side_cursors=False, use_native_unicode=True, **kwargs): def __init__(self, server_side_cursors=False, use_native_unicode=True,
client_encoding=None,
use_native_hstore=True, use_native_uuid=True,
**kwargs):
PGDialect.__init__(self, **kwargs) PGDialect.__init__(self, **kwargs)
self.server_side_cursors = server_side_cursors self.server_side_cursors = server_side_cursors
self.use_native_unicode = use_native_unicode self.use_native_unicode = use_native_unicode
self.use_native_hstore = use_native_hstore
self.use_native_uuid = use_native_uuid
self.supports_unicode_binds = use_native_unicode self.supports_unicode_binds = use_native_unicode
self.client_encoding = client_encoding
if self.dbapi and hasattr(self.dbapi, '__version__'):
m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?',
self.dbapi.__version__)
if m:
self.psycopg2_version = tuple(
int(x)
for x in m.group(1, 2, 3)
if x is not None)
def initialize(self, connection):
super(PGDialect_psycopg2, self).initialize(connection)
self._has_native_hstore = self.use_native_hstore and \
self._hstore_oids(connection.connection) \
is not None
self._has_native_json = \
self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_json']
self._has_native_jsonb = \
self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_jsonb']
# http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9
self.supports_sane_multi_rowcount = \
self.psycopg2_version >= \
self.FEATURE_VERSION_MAP['sane_multi_rowcount']
@classmethod @classmethod
def dbapi(cls): def dbapi(cls):
psycopg = __import__('psycopg2') import psycopg2
return psycopg return psycopg2
@classmethod
def _psycopg2_extensions(cls):
from psycopg2 import extensions
return extensions
@classmethod
def _psycopg2_extras(cls):
from psycopg2 import extras
return extras
@util.memoized_property
def _isolation_lookup(self):
extensions = self._psycopg2_extensions()
return {
'AUTOCOMMIT': extensions.ISOLATION_LEVEL_AUTOCOMMIT,
'READ COMMITTED': extensions.ISOLATION_LEVEL_READ_COMMITTED,
'READ UNCOMMITTED': extensions.ISOLATION_LEVEL_READ_UNCOMMITTED,
'REPEATABLE READ': extensions.ISOLATION_LEVEL_REPEATABLE_READ,
'SERIALIZABLE': extensions.ISOLATION_LEVEL_SERIALIZABLE
}
def set_isolation_level(self, connection, level):
try:
level = self._isolation_lookup[level.replace('_', ' ')]
except KeyError:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
"Valid isolation levels for %s are %s" %
(level, self.name, ", ".join(self._isolation_lookup))
)
connection.set_isolation_level(level)
def on_connect(self): def on_connect(self):
base_on_connect = super(PGDialect_psycopg2, self).on_connect() extras = self._psycopg2_extras()
extensions = self._psycopg2_extensions()
fns = []
if self.client_encoding is not None:
def on_connect(conn):
conn.set_client_encoding(self.client_encoding)
fns.append(on_connect)
if self.isolation_level is not None:
def on_connect(conn):
self.set_isolation_level(conn, self.isolation_level)
fns.append(on_connect)
if self.dbapi and self.use_native_uuid:
def on_connect(conn):
extras.register_uuid(None, conn)
fns.append(on_connect)
if self.dbapi and self.use_native_unicode: if self.dbapi and self.use_native_unicode:
extensions = __import__('psycopg2.extensions').extensions def on_connect(conn):
def connect(conn):
extensions.register_type(extensions.UNICODE, conn) extensions.register_type(extensions.UNICODE, conn)
if base_on_connect: extensions.register_type(extensions.UNICODEARRAY, conn)
base_on_connect(conn) fns.append(on_connect)
return connect
if self.dbapi and self.use_native_hstore:
def on_connect(conn):
hstore_oids = self._hstore_oids(conn)
if hstore_oids is not None:
oid, array_oid = hstore_oids
kw = {'oid': oid}
if util.py2k:
kw['unicode'] = True
if self.psycopg2_version >= \
self.FEATURE_VERSION_MAP['array_oid']:
kw['array_oid'] = array_oid
extras.register_hstore(conn, **kw)
fns.append(on_connect)
if self.dbapi and self._json_deserializer:
def on_connect(conn):
if self._has_native_json:
extras.register_default_json(
conn, loads=self._json_deserializer)
if self._has_native_jsonb:
extras.register_default_jsonb(
conn, loads=self._json_deserializer)
fns.append(on_connect)
if fns:
def on_connect(conn):
for fn in fns:
fn(conn)
return on_connect
else: else:
return base_on_connect return None
@util.memoized_instancemethod
def _hstore_oids(self, conn):
if self.psycopg2_version >= self.FEATURE_VERSION_MAP['hstore_adapter']:
extras = self._psycopg2_extras()
oids = extras.HstoreAdapter.get_oids(conn)
if oids is not None and oids[0]:
return oids[0:2]
return None
def create_connect_args(self, url): def create_connect_args(self, url):
opts = url.translate_connect_args(username='user') opts = url.translate_connect_args(username='user')
@ -224,16 +661,42 @@ class PGDialect_psycopg2(PGDialect):
opts.update(url.query) opts.update(url.query)
return ([], opts) return ([], opts)
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError): if isinstance(e, self.dbapi.Error):
return 'closed the connection' in str(e) or 'connection not open' in str(e) # check the "closed" flag. this might not be
elif isinstance(e, self.dbapi.InterfaceError): # present on old psycopg2 versions. Also,
return 'connection already closed' in str(e) or 'cursor already closed' in str(e) # this flag doesn't actually help in a lot of disconnect
elif isinstance(e, self.dbapi.ProgrammingError): # situations, so don't rely on it.
# yes, it really says "losed", not "closed" if getattr(connection, 'closed', False):
return "losed the connection unexpectedly" in str(e) return True
else:
return False # checks based on strings. in the case that .closed
# didn't cut it, fall back onto these.
str_e = str(e).partition("\n")[0]
for msg in [
# these error messages from libpq: interfaces/libpq/fe-misc.c
# and interfaces/libpq/fe-secure.c.
'terminating connection',
'closed the connection',
'connection not open',
'could not receive data from server',
'could not send data to server',
# psycopg2 client errors, psycopg2/conenction.h,
# psycopg2/cursor.h
'connection already closed',
'cursor already closed',
# not sure where this path is originally from, it may
# be obsolete. It really says "losed", not "closed".
'losed the connection unexpectedly',
# these can occur in newer SSL
'connection has been closed unexpectedly',
'SSL SYSCALL error: Bad file descriptor',
'SSL SYSCALL error: EOF detected',
'SSL error: decryption failed or bad record mac',
]:
idx = str_e.find(msg)
if idx >= 0 and '"' not in str_e[:idx]:
return True
return False
dialect = PGDialect_psycopg2 dialect = PGDialect_psycopg2

View File

@ -1,18 +1,25 @@
"""Support for the PostgreSQL database via py-postgresql. # postgresql/pypostgresql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
Connecting """
---------- .. dialect:: postgresql+pypostgresql
:name: py-postgresql
URLs are of the form `postgresql+pypostgresql://user@password@host:port/dbname[?key=value&key=value...]`. :dbapi: pypostgresql
:connectstring: postgresql+pypostgresql://user:password@host:port/dbname\
[?key=value&key=value...]
:url: http://python.projects.pgfoundry.org/
""" """
from sqlalchemy.engine import default from ... import util
import decimal from ... import types as sqltypes
from sqlalchemy import util from .base import PGDialect, PGExecutionContext
from sqlalchemy import types as sqltypes from ... import processors
from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
from sqlalchemy import processors
class PGNumeric(sqltypes.Numeric): class PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect): def bind_processor(self, dialect):
@ -24,9 +31,11 @@ class PGNumeric(sqltypes.Numeric):
else: else:
return processors.to_float return processors.to_float
class PGExecutionContext_pypostgresql(PGExecutionContext): class PGExecutionContext_pypostgresql(PGExecutionContext):
pass pass
class PGDialect_pypostgresql(PGDialect): class PGDialect_pypostgresql(PGDialect):
driver = 'pypostgresql' driver = 'pypostgresql'
@ -36,7 +45,7 @@ class PGDialect_pypostgresql(PGDialect):
default_paramstyle = 'pyformat' default_paramstyle = 'pyformat'
# requires trunk version to support sane rowcounts # requires trunk version to support sane rowcounts
# TODO: use dbapi version information to set this flag appropariately # TODO: use dbapi version information to set this flag appropriately
supports_sane_rowcount = True supports_sane_rowcount = True
supports_sane_multi_rowcount = False supports_sane_multi_rowcount = False
@ -44,8 +53,10 @@ class PGDialect_pypostgresql(PGDialect):
colspecs = util.update_copy( colspecs = util.update_copy(
PGDialect.colspecs, PGDialect.colspecs,
{ {
sqltypes.Numeric : PGNumeric, sqltypes.Numeric: PGNumeric,
sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used
# prevents PGNumeric from being used
sqltypes.Float: sqltypes.Float,
} }
) )
@ -54,6 +65,23 @@ class PGDialect_pypostgresql(PGDialect):
from postgresql.driver import dbapi20 from postgresql.driver import dbapi20
return dbapi20 return dbapi20
_DBAPI_ERROR_NAMES = [
"Error",
"InterfaceError", "DatabaseError", "DataError",
"OperationalError", "IntegrityError", "InternalError",
"ProgrammingError", "NotSupportedError"
]
@util.memoized_property
def dbapi_exception_translation_map(self):
if self.dbapi is None:
return {}
return dict(
(getattr(self.dbapi, name).__name__, name)
for name in self._DBAPI_ERROR_NAMES
)
def create_connect_args(self, url): def create_connect_args(self, url):
opts = url.translate_connect_args(username='user') opts = url.translate_connect_args(username='user')
if 'port' in opts: if 'port' in opts:
@ -63,7 +91,7 @@ class PGDialect_pypostgresql(PGDialect):
opts.update(url.query) opts.update(url.query)
return ([], opts) return ([], opts)
def is_disconnect(self, e): def is_disconnect(self, e, connection, cursor):
return "connection is closed" in str(e) return "connection is closed" in str(e)
dialect = PGDialect_pypostgresql dialect = PGDialect_pypostgresql

View File

@ -1,19 +1,46 @@
"""Support for the PostgreSQL database via the zxjdbc JDBC connector. # postgresql/zxjdbc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
JDBC Driver # <see AUTHORS file>
----------- #
# This module is part of SQLAlchemy and is released under
The official Postgresql JDBC driver is at http://jdbc.postgresql.org/. # the MIT License: http://www.opensource.org/licenses/mit-license.php
""" """
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector .. dialect:: postgresql+zxjdbc
from sqlalchemy.dialects.postgresql.base import PGDialect :name: zxJDBC for Jython
:dbapi: zxjdbc
:connectstring: postgresql+zxjdbc://scott:tiger@localhost/db
:driverurl: http://jdbc.postgresql.org/
"""
from ...connectors.zxJDBC import ZxJDBCConnector
from .base import PGDialect, PGExecutionContext
class PGExecutionContext_zxjdbc(PGExecutionContext):
def create_cursor(self):
cursor = self._dbapi_connection.cursor()
cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
return cursor
class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect): class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect):
jdbc_db_name = 'postgresql' jdbc_db_name = 'postgresql'
jdbc_driver_name = 'org.postgresql.Driver' jdbc_driver_name = 'org.postgresql.Driver'
execution_ctx_cls = PGExecutionContext_zxjdbc
supports_native_decimal = True
def __init__(self, *args, **kwargs):
super(PGDialect_zxjdbc, self).__init__(*args, **kwargs)
from com.ziclix.python.sql.handler import PostgresqlDataHandler
self.DataHandler = PostgresqlDataHandler
def _get_server_version_info(self, connection): def _get_server_version_info(self, connection):
return tuple(int(x) for x in connection.connection.dbversion.split('.')) parts = connection.connection.dbversion.split('.')
return tuple(int(x) for x in parts)
dialect = PGDialect_zxjdbc dialect = PGDialect_zxjdbc

View File

@ -5,20 +5,20 @@ Rules for Migrating TypeEngine classes to 0.6
a. Specifying behavior which needs to occur for bind parameters a. Specifying behavior which needs to occur for bind parameters
or result row columns. or result row columns.
b. Specifying types that are entirely specific to the database b. Specifying types that are entirely specific to the database
in use and have no analogue in the sqlalchemy.types package. in use and have no analogue in the sqlalchemy.types package.
c. Specifying types where there is an analogue in sqlalchemy.types, c. Specifying types where there is an analogue in sqlalchemy.types,
but the database in use takes vendor-specific flags for those but the database in use takes vendor-specific flags for those
types. types.
d. If a TypeEngine class doesn't provide any of this, it should be d. If a TypeEngine class doesn't provide any of this, it should be
*removed* from the dialect. *removed* from the dialect.
2. the TypeEngine classes are *no longer* used for generating DDL. Dialects 2. the TypeEngine classes are *no longer* used for generating DDL. Dialects
now have a TypeCompiler subclass which uses the same visit_XXX model as now have a TypeCompiler subclass which uses the same visit_XXX model as
other compilers. other compilers.
3. the "ischema_names" and "colspecs" dictionaries are now required members on 3. the "ischema_names" and "colspecs" dictionaries are now required members on
the Dialect class. the Dialect class.
@ -29,7 +29,7 @@ the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this
end users would never need to use _PGNumeric directly. However, if a dialect-specific end users would never need to use _PGNumeric directly. However, if a dialect-specific
type is specifying a type *or* arguments that are not present generically, it should type is specifying a type *or* arguments that are not present generically, it should
match the real name of the type on that backend, in uppercase. E.g. postgresql.INET, match the real name of the type on that backend, in uppercase. E.g. postgresql.INET,
mysql.ENUM, postgresql.ARRAY. mysql.ENUM, postgresql.ARRAY.
Or follow this handy flowchart: Or follow this handy flowchart:
@ -61,8 +61,8 @@ Or follow this handy flowchart:
| |
v v
the type should the type should
subclass the subclass the
UPPERCASE UPPERCASE
type in types.py type in types.py
(i.e. class BLOB(types.BLOB)) (i.e. class BLOB(types.BLOB))
@ -85,15 +85,15 @@ Example 4. MySQL has a SET type, there's no analogue for this in types.py. So
MySQL names it SET in the dialect's base.py, and it subclasses types.String, since MySQL names it SET in the dialect's base.py, and it subclasses types.String, since
it ultimately deals with strings. it ultimately deals with strings.
Example 5. Postgresql has a DATETIME type. The DBAPIs handle dates correctly, Example 5. PostgreSQL has a DATETIME type. The DBAPIs handle dates correctly,
and no special arguments are used in PG's DDL beyond what types.py provides. and no special arguments are used in PG's DDL beyond what types.py provides.
Postgresql dialect therefore imports types.DATETIME into its base.py. PostgreSQL dialect therefore imports types.DATETIME into its base.py.
Ideally one should be able to specify a schema using names imported completely from a Ideally one should be able to specify a schema using names imported completely from a
dialect, all matching the real name on that backend: dialect, all matching the real name on that backend:
from sqlalchemy.dialects.postgresql import base as pg from sqlalchemy.dialects.postgresql import base as pg
t = Table('mytable', metadata, t = Table('mytable', metadata,
Column('id', pg.INTEGER, primary_key=True), Column('id', pg.INTEGER, primary_key=True),
Column('name', pg.VARCHAR(300)), Column('name', pg.VARCHAR(300)),
@ -110,36 +110,36 @@ indicate a special type only available in this database, it must be *removed* fr
module and from this dictionary. module and from this dictionary.
6. "ischema_names" indicates string descriptions of types as returned from the database 6. "ischema_names" indicates string descriptions of types as returned from the database
linked to TypeEngine classes. linked to TypeEngine classes.
a. The string name should be matched to the most specific type possible within a. The string name should be matched to the most specific type possible within
sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which
case it points to a dialect type. *It doesn't matter* if the dialect has it's case it points to a dialect type. *It doesn't matter* if the dialect has its
own subclass of that type with special bind/result behavior - reflect to the types.py own subclass of that type with special bind/result behavior - reflect to the types.py
UPPERCASE type as much as possible. With very few exceptions, all types UPPERCASE type as much as possible. With very few exceptions, all types
should reflect to an UPPERCASE type. should reflect to an UPPERCASE type.
b. If the dialect contains a matching dialect-specific type that takes extra arguments b. If the dialect contains a matching dialect-specific type that takes extra arguments
which the generic one does not, then point to the dialect-specific type. E.g. which the generic one does not, then point to the dialect-specific type. E.g.
mssql.VARCHAR takes a "collation" parameter which should be preserved. mssql.VARCHAR takes a "collation" parameter which should be preserved.
5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by 5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by
a subclass of compiler.GenericTypeCompiler. a subclass of compiler.GenericTypeCompiler.
a. your TypeCompiler class will receive generic and uppercase types from a. your TypeCompiler class will receive generic and uppercase types from
sqlalchemy.types. Do not assume the presence of dialect-specific attributes on sqlalchemy.types. Do not assume the presence of dialect-specific attributes on
these types. these types.
b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with
methods that produce a different DDL name. Uppercase types don't do any kind of methods that produce a different DDL name. Uppercase types don't do any kind of
"guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in
all cases, regardless of whether or not that type is legal on the backend database. all cases, regardless of whether or not that type is legal on the backend database.
c. the visit_UPPERCASE methods *should* be overridden with methods that add additional c. the visit_UPPERCASE methods *should* be overridden with methods that add additional
arguments and flags to those types. arguments and flags to those types.
d. the visit_lowercase methods are overridden to provide an interpretation of a generic d. the visit_lowercase methods are overridden to provide an interpretation of a generic
type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)". type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)".
e. visit_lowercase methods should *never* render strings directly - it should always e. visit_lowercase methods should *never* render strings directly - it should always
be via calling a visit_UPPERCASE() method. be via calling a visit_UPPERCASE() method.

View File

@ -1,5 +1,6 @@
# engine/__init__.py # engine/__init__.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -9,7 +10,7 @@
The engine package defines the basic components used to interface The engine package defines the basic components used to interface
DB-API modules with higher-level statement construction, DB-API modules with higher-level statement construction,
connection-management, execution and result contexts. The primary connection-management, execution and result contexts. The primary
"entry point" class into this package is the Engine and it's public "entry point" class into this package is the Engine and its public
constructor ``create_engine()``. constructor ``create_engine()``.
This package includes: This package includes:
@ -50,94 +51,125 @@ url.py
within a URL. within a URL.
""" """
# not sure what this was used for from .interfaces import (
#import sqlalchemy.databases Connectable,
CreateEnginePlugin,
Dialect,
ExecutionContext,
ExceptionContext,
from sqlalchemy.engine.base import ( # backwards compat
Compiled,
TypeCompiler
)
from .base import (
Connection,
Engine,
NestedTransaction,
RootTransaction,
Transaction,
TwoPhaseTransaction,
)
from .result import (
BaseRowProxy,
BufferedColumnResultProxy, BufferedColumnResultProxy,
BufferedColumnRow, BufferedColumnRow,
BufferedRowResultProxy, BufferedRowResultProxy,
Compiled, FullyBufferedResultProxy,
Connectable,
Connection,
Dialect,
Engine,
ExecutionContext,
NestedTransaction,
ResultProxy, ResultProxy,
RootTransaction,
RowProxy, RowProxy,
Transaction, )
TwoPhaseTransaction,
TypeCompiler from .util import (
) connection_memoize
from sqlalchemy.engine import strategies )
from sqlalchemy import util
__all__ = ( from . import util, strategies
'BufferedColumnResultProxy',
'BufferedColumnRow',
'BufferedRowResultProxy',
'Compiled',
'Connectable',
'Connection',
'Dialect',
'Engine',
'ExecutionContext',
'NestedTransaction',
'ResultProxy',
'RootTransaction',
'RowProxy',
'Transaction',
'TwoPhaseTransaction',
'TypeCompiler',
'create_engine',
'engine_from_config',
)
# backwards compat
from ..sql import ddl
default_strategy = 'plain' default_strategy = 'plain'
def create_engine(*args, **kwargs): def create_engine(*args, **kwargs):
"""Create a new Engine instance. """Create a new :class:`.Engine` instance.
The standard method of specifying the engine is via URL as the The standard calling form is to send the URL as the
first positional argument, to indicate the appropriate database first positional argument, usually a string
dialect and connection arguments, with additional keyword that indicates database dialect and connection arguments::
arguments sent as options to the dialect and resulting Engine.
The URL is a string in the form
``dialect+driver://user:password@host/dbname[?key=value..]``, where engine = create_engine("postgresql://scott:tiger@localhost/test")
``dialect`` is a database name such as ``mysql``, ``oracle``,
``postgresql``, etc., and ``driver`` the name of a DBAPI, such as Additional keyword arguments may then follow it which
``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively, establish various options on the resulting :class:`.Engine`
and its underlying :class:`.Dialect` and :class:`.Pool`
constructs::
engine = create_engine("mysql://scott:tiger@hostname/dbname",
encoding='latin1', echo=True)
The string form of the URL is
``dialect[+driver]://user:password@host/dbname[?key=value..]``, where
``dialect`` is a database name such as ``mysql``, ``oracle``,
``postgresql``, etc., and ``driver`` the name of a DBAPI, such as
``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively,
the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`. the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`.
`**kwargs` takes a wide variety of options which are routed ``**kwargs`` takes a wide variety of options which are routed
towards their appropriate components. Arguments may be towards their appropriate components. Arguments may be specific to
specific to the Engine, the underlying Dialect, as well as the the :class:`.Engine`, the underlying :class:`.Dialect`, as well as the
Pool. Specific dialects also accept keyword arguments that :class:`.Pool`. Specific dialects also accept keyword arguments that
are unique to that dialect. Here, we describe the parameters are unique to that dialect. Here, we describe the parameters
that are common to most ``create_engine()`` usage. that are common to most :func:`.create_engine()` usage.
:param assert_unicode: Deprecated. A warning is raised in all cases when a non-Unicode Once established, the newly resulting :class:`.Engine` will
object is passed when SQLAlchemy would coerce into an encoding request a connection from the underlying :class:`.Pool` once
(note: but **not** when the DBAPI handles unicode objects natively). :meth:`.Engine.connect` is called, or a method which depends on it
To suppress or raise this warning to an such as :meth:`.Engine.execute` is invoked. The :class:`.Pool` in turn
error, use the Python warnings filter documented at: will establish the first actual DBAPI connection when this request
http://docs.python.org/library/warnings.html is received. The :func:`.create_engine` call itself does **not**
establish any actual DBAPI connections directly.
.. seealso::
:doc:`/core/engines`
:doc:`/dialects/index`
:ref:`connections_toplevel`
:param case_sensitive=True: if False, result column names
will match in a case-insensitive fashion, that is,
``row['SomeColumn']``.
.. versionchanged:: 0.8
By default, result row names match case-sensitively.
In version 0.7 and prior, all matches were case-insensitive.
:param connect_args: a dictionary of options which will be :param connect_args: a dictionary of options which will be
passed directly to the DBAPI's ``connect()`` method as passed directly to the DBAPI's ``connect()`` method as
additional keyword arguments. additional keyword arguments. See the example
at :ref:`custom_dbapi_args`.
:param convert_unicode=False: if set to True, all :param convert_unicode=False: if set to True, sets
String/character based types will convert Unicode values to raw the default behavior of ``convert_unicode`` on the
byte values going into the database, and all raw byte values to :class:`.String` type to ``True``, regardless
Python Unicode coming out in result sets. This is an of a setting of ``False`` on an individual
engine-wide method to provide unicode conversion across the :class:`.String` type, thus causing all :class:`.String`
board. For unicode conversion on a column-by-column level, use -based columns
the ``Unicode`` column type instead, described in `types`. to accommodate Python ``unicode`` objects. This flag
is useful as an engine-wide setting when using a
DBAPI that does not natively support Python
``unicode`` objects and raises an error when
one is received (such as pyodbc with FreeTDS).
See :class:`.String` for further details on
what this flag indicates.
:param creator: a callable which returns a DBAPI connection. :param creator: a callable which returns a DBAPI connection.
This creation function will be passed to the underlying This creation function will be passed to the underlying
@ -160,23 +192,105 @@ def create_engine(*args, **kwargs):
:ref:`dbengine_logging` for information on how to configure logging :ref:`dbengine_logging` for information on how to configure logging
directly. directly.
:param encoding='utf-8': the encoding to use for all Unicode :param encoding: Defaults to ``utf-8``. This is the string
translations, both by engine-wide unicode conversion as well as encoding used by SQLAlchemy for string encode/decode
the ``Unicode`` type object. operations which occur within SQLAlchemy, **outside of
the DBAPI.** Most modern DBAPIs feature some degree of
direct support for Python ``unicode`` objects,
what you see in Python 2 as a string of the form
``u'some string'``. For those scenarios where the
DBAPI is detected as not supporting a Python ``unicode``
object, this encoding is used to determine the
source/destination encoding. It is **not used**
for those cases where the DBAPI handles unicode
directly.
To properly configure a system to accommodate Python
``unicode`` objects, the DBAPI should be
configured to handle unicode to the greatest
degree as is appropriate - see
the notes on unicode pertaining to the specific
target database in use at :ref:`dialect_toplevel`.
Areas where string encoding may need to be accommodated
outside of the DBAPI include zero or more of:
* the values passed to bound parameters, corresponding to
the :class:`.Unicode` type or the :class:`.String` type
when ``convert_unicode`` is ``True``;
* the values returned in result set columns corresponding
to the :class:`.Unicode` type or the :class:`.String`
type when ``convert_unicode`` is ``True``;
* the string SQL statement passed to the DBAPI's
``cursor.execute()`` method;
* the string names of the keys in the bound parameter
dictionary passed to the DBAPI's ``cursor.execute()``
as well as ``cursor.setinputsizes()`` methods;
* the string column names retrieved from the DBAPI's
``cursor.description`` attribute.
When using Python 3, the DBAPI is required to support
*all* of the above values as Python ``unicode`` objects,
which in Python 3 are just known as ``str``. In Python 2,
the DBAPI does not specify unicode behavior at all,
so SQLAlchemy must make decisions for each of the above
values on a per-DBAPI basis - implementations are
completely inconsistent in their behavior.
:param execution_options: Dictionary execution options which will
be applied to all connections. See
:meth:`~sqlalchemy.engine.Connection.execution_options`
:param implicit_returning=True: When ``True``, a RETURNING-
compatible construct, if available, will be used to
fetch newly generated primary key values when a single row
INSERT statement is emitted with no existing returning()
clause. This applies to those backends which support RETURNING
or a compatible construct, including PostgreSQL, Firebird, Oracle,
Microsoft SQL Server. Set this to ``False`` to disable
the automatic usage of RETURNING.
:param isolation_level: this string parameter is interpreted by various
dialects in order to affect the transaction isolation level of the
database connection. The parameter essentially accepts some subset of
these string arguments: ``"SERIALIZABLE"``, ``"REPEATABLE_READ"``,
``"READ_COMMITTED"``, ``"READ_UNCOMMITTED"`` and ``"AUTOCOMMIT"``.
Behavior here varies per backend, and
individual dialects should be consulted directly.
Note that the isolation level can also be set on a per-:class:`.Connection`
basis as well, using the
:paramref:`.Connection.execution_options.isolation_level`
feature.
.. seealso::
:attr:`.Connection.default_isolation_level` - view default level
:paramref:`.Connection.execution_options.isolation_level`
- set per :class:`.Connection` isolation level
:ref:`SQLite Transaction Isolation <sqlite_isolation_level>`
:ref:`PostgreSQL Transaction Isolation <postgresql_isolation_level>`
:ref:`MySQL Transaction Isolation <mysql_isolation_level>`
:ref:`session_transaction_isolation` - for the ORM
:param label_length=None: optional integer value which limits :param label_length=None: optional integer value which limits
the size of dynamically generated column labels to that many the size of dynamically generated column labels to that many
characters. If less than 6, labels are generated as characters. If less than 6, labels are generated as
"_(counter)". If ``None``, the value of "_(counter)". If ``None``, the value of
``dialect.max_identifier_length`` is used instead. ``dialect.max_identifier_length`` is used instead.
:param listeners: A list of one or more :param listeners: A list of one or more
:class:`~sqlalchemy.interfaces.PoolListener` objects which will :class:`~sqlalchemy.interfaces.PoolListener` objects which will
receive connection pool events. receive connection pool events.
:param logging_name: String identifier which will be used within :param logging_name: String identifier which will be used within
the "name" field of logging records generated within the the "name" field of logging records generated within the
"sqlalchemy.engine" logger. Defaults to a hexstring of the "sqlalchemy.engine" logger. Defaults to a hexstring of the
object's id. object's id.
:param max_overflow=10: the number of connections to allow in :param max_overflow=10: the number of connections to allow in
@ -184,10 +298,24 @@ def create_engine(*args, **kwargs):
opened above and beyond the pool_size setting, which defaults opened above and beyond the pool_size setting, which defaults
to five. this is only used with :class:`~sqlalchemy.pool.QueuePool`. to five. this is only used with :class:`~sqlalchemy.pool.QueuePool`.
:param module=None: used by database implementations which :param module=None: reference to a Python module object (the module
support multiple DBAPI modules, this is a reference to a DBAPI2 itself, not its string name). Specifies an alternate DBAPI module to
module to be used instead of the engine's default module. For be used by the engine's dialect. Each sub-dialect references a
PostgreSQL, the default is psycopg2. For Oracle, it's cx_Oracle. specific DBAPI which will be imported before first connect. This
parameter causes the import to be bypassed, and the given module to
be used instead. Can be used for testing of DBAPIs as well as to
inject "mock" DBAPI implementations into the :class:`.Engine`.
:param paramstyle=None: The `paramstyle <http://legacy.python.org/dev/peps/pep-0249/#paramstyle>`_
to use when rendering bound parameters. This style defaults to the
one recommended by the DBAPI itself, which is retrieved from the
``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept
more than one paramstyle, and in particular it may be desirable
to change a "named" paramstyle into a "positional" one, or vice versa.
When this attribute is passed, it should be one of the values
``"qmark"``, ``"numeric"``, ``"named"``, ``"format"`` or
``"pyformat"``, and should correspond to a parameter style known
to be supported by the DBAPI in use.
:param pool=None: an already-constructed instance of :param pool=None: an already-constructed instance of
:class:`~sqlalchemy.pool.Pool`, such as a :class:`~sqlalchemy.pool.Pool`, such as a
@ -195,7 +323,7 @@ def create_engine(*args, **kwargs):
pool will be used directly as the underlying connection pool pool will be used directly as the underlying connection pool
for the engine, bypassing whatever connection parameters are for the engine, bypassing whatever connection parameters are
present in the URL argument. For information on constructing present in the URL argument. For information on constructing
connection pools manually, see `pooling`. connection pools manually, see :ref:`pooling_toplevel`.
:param poolclass=None: a :class:`~sqlalchemy.pool.Pool` :param poolclass=None: a :class:`~sqlalchemy.pool.Pool`
subclass, which will be used to create a connection pool subclass, which will be used to create a connection pool
@ -205,70 +333,102 @@ def create_engine(*args, **kwargs):
of pool to be used. of pool to be used.
:param pool_logging_name: String identifier which will be used within :param pool_logging_name: String identifier which will be used within
the "name" field of logging records generated within the the "name" field of logging records generated within the
"sqlalchemy.pool" logger. Defaults to a hexstring of the object's "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
id. id.
:param pool_size=5: the number of connections to keep open :param pool_size=5: the number of connections to keep open
inside the connection pool. This used with :class:`~sqlalchemy.pool.QueuePool` as inside the connection pool. This used with
well as :class:`~sqlalchemy.pool.SingletonThreadPool`. :class:`~sqlalchemy.pool.QueuePool` as
well as :class:`~sqlalchemy.pool.SingletonThreadPool`. With
:class:`~sqlalchemy.pool.QueuePool`, a ``pool_size`` setting
of 0 indicates no limit; to disable pooling, set ``poolclass`` to
:class:`~sqlalchemy.pool.NullPool` instead.
:param pool_recycle=-1: this setting causes the pool to recycle :param pool_recycle=-1: this setting causes the pool to recycle
connections after the given number of seconds has passed. It connections after the given number of seconds has passed. It
defaults to -1, or no timeout. For example, setting to 3600 defaults to -1, or no timeout. For example, setting to 3600
means connections will be recycled after one hour. Note that means connections will be recycled after one hour. Note that
MySQL in particular will ``disconnect automatically`` if no MySQL in particular will disconnect automatically if no
activity is detected on a connection for eight hours (although activity is detected on a connection for eight hours (although
this is configurable with the MySQLDB connection itself and the this is configurable with the MySQLDB connection itself and the
server configuration as well). server configuration as well).
:param pool_reset_on_return='rollback': set the "reset on return"
behavior of the pool, which is whether ``rollback()``,
``commit()``, or nothing is called upon connections
being returned to the pool. See the docstring for
``reset_on_return`` at :class:`.Pool`.
.. versionadded:: 0.7.6
:param pool_timeout=30: number of seconds to wait before giving :param pool_timeout=30: number of seconds to wait before giving
up on getting a connection from the pool. This is only used up on getting a connection from the pool. This is only used
with :class:`~sqlalchemy.pool.QueuePool`. with :class:`~sqlalchemy.pool.QueuePool`.
:param strategy='plain': used to invoke alternate :class:`~sqlalchemy.engine.base.Engine.` :param strategy='plain': selects alternate engine implementations.
implementations. Currently available is the ``threadlocal`` Currently available are:
strategy, which is described in :ref:`threadlocal_strategy`.
* the ``threadlocal`` strategy, which is described in
:ref:`threadlocal_strategy`;
* the ``mock`` strategy, which dispatches all statement
execution to a function passed as the argument ``executor``.
See `example in the FAQ
<http://docs.sqlalchemy.org/en/latest/faq/metadata_schema.html#how-can-i-get-the-create-table-drop-table-output-as-a-string>`_.
:param executor=None: a function taking arguments
``(sql, *multiparams, **params)``, to which the ``mock`` strategy will
dispatch all statement execution. Used only by ``strategy='mock'``.
""" """
strategy = kwargs.pop('strategy', default_strategy) strategy = kwargs.pop('strategy', default_strategy)
strategy = strategies.strategies[strategy] strategy = strategies.strategies[strategy]
return strategy.create(*args, **kwargs) return strategy.create(*args, **kwargs)
def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs): def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
"""Create a new Engine instance using a configuration dictionary. """Create a new Engine instance using a configuration dictionary.
The dictionary is typically produced from a config file where keys The dictionary is typically produced from a config file.
are prefixed, such as sqlalchemy.url, sqlalchemy.echo, etc. The
'prefix' argument indicates the prefix to be searched for. The keys of interest to ``engine_from_config()`` should be prefixed, e.g.
``sqlalchemy.url``, ``sqlalchemy.echo``, etc. The 'prefix' argument
indicates the prefix to be searched for. Each matching key (after the
prefix is stripped) is treated as though it were the corresponding keyword
argument to a :func:`.create_engine` call.
The only required key is (assuming the default prefix) ``sqlalchemy.url``,
which provides the :ref:`database URL <database_urls>`.
A select set of keyword arguments will be "coerced" to their A select set of keyword arguments will be "coerced" to their
expected type based on string values. In a future release, this expected type based on string values. The set of arguments
functionality will be expanded and include dialect-specific is extensible per-dialect using the ``engine_config_types`` accessor.
arguments.
:param configuration: A dictionary (typically produced from a config file,
but this is not a requirement). Items whose keys start with the value
of 'prefix' will have that prefix stripped, and will then be passed to
:ref:`create_engine`.
:param prefix: Prefix to match and then strip from keys
in 'configuration'.
:param kwargs: Each keyword argument to ``engine_from_config()`` itself
overrides the corresponding item taken from the 'configuration'
dictionary. Keyword arguments should *not* be prefixed.
""" """
opts = _coerce_config(configuration, prefix)
opts.update(kwargs)
url = opts.pop('url')
return create_engine(url, **opts)
def _coerce_config(configuration, prefix):
"""Convert configuration values to expected types."""
options = dict((key[len(prefix):], configuration[key]) options = dict((key[len(prefix):], configuration[key])
for key in configuration for key in configuration
if key.startswith(prefix)) if key.startswith(prefix))
for option, type_ in ( options['_coerce_config'] = True
('convert_unicode', bool), options.update(kwargs)
('pool_timeout', int), url = options.pop('url')
('echo', bool), return create_engine(url, **options)
('echo_pool', bool),
('pool_recycle', int),
('pool_size', int), __all__ = (
('max_overflow', int), 'create_engine',
('pool_threadlocal', bool), 'engine_from_config',
): )
util.coerce_kw_type(options, option, type_)
return options

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,10 @@
# engine/reflection.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Provides an abstraction for obtaining database schema information. """Provides an abstraction for obtaining database schema information.
Usage Notes: Usage Notes:
@ -18,11 +25,14 @@ methods such as get_table_names, get_columns, etc.
'name' attribute.. 'name' attribute..
""" """
import sqlalchemy from .. import exc, sql
from sqlalchemy import exc, sql from ..sql import schema as sa_schema
from sqlalchemy import util from .. import util
from sqlalchemy.types import TypeEngine from ..sql.type_api import TypeEngine
from sqlalchemy import schema as sa_schema from ..util import deprecated
from ..util import topological
from .. import inspection
from .base import Connectable
@util.decorator @util.decorator
@ -31,10 +41,14 @@ def cache(fn, self, con, *args, **kw):
if info_cache is None: if info_cache is None:
return fn(self, con, *args, **kw) return fn(self, con, *args, **kw)
key = ( key = (
fn.__name__, fn.__name__,
tuple(a for a in args if isinstance(a, basestring)), tuple(a for a in args if isinstance(a, util.string_types)),
tuple((k, v) for k, v in kw.iteritems() if isinstance(v, (basestring, int, float))) tuple((k, v) for k, v in kw.items() if
) isinstance(v,
util.string_types + util.int_types + (float, )
)
)
)
ret = info_cache.get(key) ret = info_cache.get(key)
if ret is None: if ret is None:
ret = fn(self, con, *args, **kw) ret = fn(self, con, *args, **kw)
@ -45,33 +59,94 @@ def cache(fn, self, con, *args, **kw):
class Inspector(object): class Inspector(object):
"""Performs database schema inspection. """Performs database schema inspection.
The Inspector acts as a proxy to the dialects' reflection methods and The Inspector acts as a proxy to the reflection methods of the
provides higher level functions for accessing database schema information. :class:`~sqlalchemy.engine.interfaces.Dialect`, providing a
consistent interface as well as caching support for previously
fetched metadata.
A :class:`.Inspector` object is usually created via the
:func:`.inspect` function::
from sqlalchemy import inspect, create_engine
engine = create_engine('...')
insp = inspect(engine)
The inspection method above is equivalent to using the
:meth:`.Inspector.from_engine` method, i.e.::
engine = create_engine('...')
insp = Inspector.from_engine(engine)
Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` may opt
to return an :class:`.Inspector` subclass that provides additional
methods specific to the dialect's target database.
""" """
def __init__(self, conn): def __init__(self, bind):
"""Initialize the instance. """Initialize a new :class:`.Inspector`.
:param bind: a :class:`~sqlalchemy.engine.Connectable`,
which is typically an instance of
:class:`~sqlalchemy.engine.Engine` or
:class:`~sqlalchemy.engine.Connection`.
For a dialect-specific instance of :class:`.Inspector`, see
:meth:`.Inspector.from_engine`
:param conn: a :class:`~sqlalchemy.engine.base.Connectable`
""" """
# this might not be a connection, it could be an engine.
self.bind = bind
self.conn = conn
# set the engine # set the engine
if hasattr(conn, 'engine'): if hasattr(bind, 'engine'):
self.engine = conn.engine self.engine = bind.engine
else: else:
self.engine = conn self.engine = bind
if self.engine is bind:
# if engine, ensure initialized
bind.connect().close()
self.dialect = self.engine.dialect self.dialect = self.engine.dialect
self.info_cache = {} self.info_cache = {}
@classmethod @classmethod
def from_engine(cls, engine): def from_engine(cls, bind):
if hasattr(engine.dialect, 'inspector'): """Construct a new dialect-specific Inspector object from the given
return engine.dialect.inspector(engine) engine or connection.
return Inspector(engine)
:param bind: a :class:`~sqlalchemy.engine.Connectable`,
which is typically an instance of
:class:`~sqlalchemy.engine.Engine` or
:class:`~sqlalchemy.engine.Connection`.
This method differs from direct a direct constructor call of
:class:`.Inspector` in that the
:class:`~sqlalchemy.engine.interfaces.Dialect` is given a chance to
provide a dialect-specific :class:`.Inspector` instance, which may
provide additional methods.
See the example at :class:`.Inspector`.
"""
if hasattr(bind.dialect, 'inspector'):
return bind.dialect.inspector(bind)
return Inspector(bind)
@inspection._inspects(Connectable)
def _insp(bind):
return Inspector.from_engine(bind)
@property @property
def default_schema_name(self): def default_schema_name(self):
"""Return the default schema name presented by the dialect
for the current engine's database user.
E.g. this is typically ``public`` for PostgreSQL and ``dbo``
for SQL Server.
"""
return self.dialect.default_schema_name return self.dialect.default_schema_name
def get_schema_names(self): def get_schema_names(self):
@ -79,70 +154,185 @@ class Inspector(object):
""" """
if hasattr(self.dialect, 'get_schema_names'): if hasattr(self.dialect, 'get_schema_names'):
return self.dialect.get_schema_names(self.conn, return self.dialect.get_schema_names(self.bind,
info_cache=self.info_cache) info_cache=self.info_cache)
return [] return []
def get_table_names(self, schema=None, order_by=None): def get_table_names(self, schema=None, order_by=None):
"""Return all table names in `schema`. """Return all table names in referred to within a particular schema.
The names are expected to be real tables only, not views.
Views are instead returned using the :meth:`.Inspector.get_view_names`
method.
:param schema: Schema name. If ``schema`` is left at ``None``, the
database's default schema is
used, else the named schema is searched. If the database does not
support named schemas, behavior is undefined if ``schema`` is not
passed as ``None``. For special quoting, use :class:`.quoted_name`.
:param schema: Optional, retrieve names from a non-default schema.
:param order_by: Optional, may be the string "foreign_key" to sort :param order_by: Optional, may be the string "foreign_key" to sort
the result on foreign key dependencies. the result on foreign key dependencies. Does not automatically
resolve cycles, and will raise :class:`.CircularDependencyError`
if cycles exist.
.. deprecated:: 1.0.0 - see
:meth:`.Inspector.get_sorted_table_and_fkc_names` for a version
of this which resolves foreign key cycles between tables
automatically.
.. versionchanged:: 0.8 the "foreign_key" sorting sorts tables
in order of dependee to dependent; that is, in creation
order, rather than in drop order. This is to maintain
consistency with similar features such as
:attr:`.MetaData.sorted_tables` and :func:`.util.sort_tables`.
.. seealso::
:meth:`.Inspector.get_sorted_table_and_fkc_names`
:attr:`.MetaData.sorted_tables`
This should probably not return view names or maybe it should return
them with an indicator t or v.
""" """
if hasattr(self.dialect, 'get_table_names'): if hasattr(self.dialect, 'get_table_names'):
tnames = self.dialect.get_table_names(self.conn, tnames = self.dialect.get_table_names(
schema, self.bind, schema, info_cache=self.info_cache)
info_cache=self.info_cache)
else: else:
tnames = self.engine.table_names(schema) tnames = self.engine.table_names(schema)
if order_by == 'foreign_key': if order_by == 'foreign_key':
ordered_tnames = tnames[:] tuples = []
# Order based on foreign key dependencies.
for tname in tnames: for tname in tnames:
table_pos = tnames.index(tname) for fkey in self.get_foreign_keys(tname, schema):
fkeys = self.get_foreign_keys(tname, schema) if tname != fkey['referred_table']:
for fkey in fkeys: tuples.append((fkey['referred_table'], tname))
rtable = fkey['referred_table'] tnames = list(topological.sort(tuples, tnames))
if rtable in ordered_tnames:
ref_pos = ordered_tnames.index(rtable)
# Make sure it's lower in the list than anything it
# references.
if table_pos > ref_pos:
ordered_tnames.pop(table_pos) # rtable moves up 1
# insert just below rtable
ordered_tnames.index(ref_pos, tname)
tnames = ordered_tnames
return tnames return tnames
def get_sorted_table_and_fkc_names(self, schema=None):
"""Return dependency-sorted table and foreign key constraint names in
referred to within a particular schema.
This will yield 2-tuples of
``(tablename, [(tname, fkname), (tname, fkname), ...])``
consisting of table names in CREATE order grouped with the foreign key
constraint names that are not detected as belonging to a cycle.
The final element
will be ``(None, [(tname, fkname), (tname, fkname), ..])``
which will consist of remaining
foreign key constraint names that would require a separate CREATE
step after-the-fact, based on dependencies between tables.
.. versionadded:: 1.0.-
.. seealso::
:meth:`.Inspector.get_table_names`
:func:`.sort_tables_and_constraints` - similar method which works
with an already-given :class:`.MetaData`.
"""
if hasattr(self.dialect, 'get_table_names'):
tnames = self.dialect.get_table_names(
self.bind, schema, info_cache=self.info_cache)
else:
tnames = self.engine.table_names(schema)
tuples = set()
remaining_fkcs = set()
fknames_for_table = {}
for tname in tnames:
fkeys = self.get_foreign_keys(tname, schema)
fknames_for_table[tname] = set(
[fk['name'] for fk in fkeys]
)
for fkey in fkeys:
if tname != fkey['referred_table']:
tuples.add((fkey['referred_table'], tname))
try:
candidate_sort = list(topological.sort(tuples, tnames))
except exc.CircularDependencyError as err:
for edge in err.edges:
tuples.remove(edge)
remaining_fkcs.update(
(edge[1], fkc)
for fkc in fknames_for_table[edge[1]]
)
candidate_sort = list(topological.sort(tuples, tnames))
return [
(tname, fknames_for_table[tname].difference(remaining_fkcs))
for tname in candidate_sort
] + [(None, list(remaining_fkcs))]
def get_temp_table_names(self):
"""return a list of temporary table names for the current bind.
This method is unsupported by most dialects; currently
only SQLite implements it.
.. versionadded:: 1.0.0
"""
return self.dialect.get_temp_table_names(
self.bind, info_cache=self.info_cache)
def get_temp_view_names(self):
"""return a list of temporary view names for the current bind.
This method is unsupported by most dialects; currently
only SQLite implements it.
.. versionadded:: 1.0.0
"""
return self.dialect.get_temp_view_names(
self.bind, info_cache=self.info_cache)
def get_table_options(self, table_name, schema=None, **kw): def get_table_options(self, table_name, schema=None, **kw):
"""Return a dictionary of options specified when the table of the
given name was created.
This currently includes some options that apply to MySQL tables.
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
"""
if hasattr(self.dialect, 'get_table_options'): if hasattr(self.dialect, 'get_table_options'):
return self.dialect.get_table_options(self.conn, table_name, schema, return self.dialect.get_table_options(
info_cache=self.info_cache, self.bind, table_name, schema,
**kw) info_cache=self.info_cache, **kw)
return {} return {}
def get_view_names(self, schema=None): def get_view_names(self, schema=None):
"""Return all view names in `schema`. """Return all view names in `schema`.
:param schema: Optional, retrieve names from a non-default schema. :param schema: Optional, retrieve names from a non-default schema.
For special quoting, use :class:`.quoted_name`.
""" """
return self.dialect.get_view_names(self.conn, schema, return self.dialect.get_view_names(self.bind, schema,
info_cache=self.info_cache) info_cache=self.info_cache)
def get_view_definition(self, view_name, schema=None): def get_view_definition(self, view_name, schema=None):
"""Return definition for `view_name`. """Return definition for `view_name`.
:param schema: Optional, retrieve names from a non-default schema. :param schema: Optional, retrieve names from a non-default schema.
For special quoting, use :class:`.quoted_name`.
""" """
return self.dialect.get_view_definition( return self.dialect.get_view_definition(
self.conn, view_name, schema, info_cache=self.info_cache) self.bind, view_name, schema, info_cache=self.info_cache)
def get_columns(self, table_name, schema=None, **kw): def get_columns(self, table_name, schema=None, **kw):
"""Return information about columns in `table_name`. """Return information about columns in `table_name`.
@ -150,23 +340,31 @@ class Inspector(object):
Given a string `table_name` and an optional string `schema`, return Given a string `table_name` and an optional string `schema`, return
column information as a list of dicts with these keys: column information as a list of dicts with these keys:
name * ``name`` - the column's name
the column's name
type * ``type`` - the type of this column; an instance of
:class:`~sqlalchemy.types.TypeEngine` :class:`~sqlalchemy.types.TypeEngine`
nullable * ``nullable`` - boolean flag if the column is NULL or NOT NULL
boolean
default * ``default`` - the column's server default value - this is returned
the column's default value as a string SQL expression.
* ``attrs`` - dict containing optional column attributes
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
:return: list of dictionaries, each representing the definition of
a database column.
attrs
dict containing optional column attributes
""" """
col_defs = self.dialect.get_columns(self.conn, table_name, schema, col_defs = self.dialect.get_columns(self.bind, table_name, schema,
info_cache=self.info_cache, info_cache=self.info_cache,
**kw) **kw)
for col_def in col_defs: for col_def in col_defs:
@ -176,6 +374,8 @@ class Inspector(object):
col_def['type'] = coltype() col_def['type'] = coltype()
return col_defs return col_defs
@deprecated('0.7', 'Call to deprecated method get_primary_keys.'
' Use get_pk_constraint instead.')
def get_primary_keys(self, table_name, schema=None, **kw): def get_primary_keys(self, table_name, schema=None, **kw):
"""Return information about primary keys in `table_name`. """Return information about primary keys in `table_name`.
@ -183,12 +383,34 @@ class Inspector(object):
primary key information as a list of column names. primary key information as a list of column names.
""" """
pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema, return self.dialect.get_pk_constraint(self.bind, table_name, schema,
info_cache=self.info_cache,
**kw)['constrained_columns']
def get_pk_constraint(self, table_name, schema=None, **kw):
"""Return information about primary key constraint on `table_name`.
Given a string `table_name`, and an optional string `schema`, return
primary key information as a dictionary with these keys:
constrained_columns
a list of column names that make up the primary key
name
optional name of the primary key constraint.
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
"""
return self.dialect.get_pk_constraint(self.bind, table_name, schema,
info_cache=self.info_cache, info_cache=self.info_cache,
**kw) **kw)
return pkeys
def get_foreign_keys(self, table_name, schema=None, **kw): def get_foreign_keys(self, table_name, schema=None, **kw):
"""Return information about foreign_keys in `table_name`. """Return information about foreign_keys in `table_name`.
@ -208,15 +430,21 @@ class Inspector(object):
a list of column names in the referred table that correspond to a list of column names in the referred table that correspond to
constrained_columns constrained_columns
\**kw name
other options passed to the dialect's get_foreign_keys() method. optional name of the foreign key constraint.
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
""" """
fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema, return self.dialect.get_foreign_keys(self.bind, table_name, schema,
info_cache=self.info_cache, info_cache=self.info_cache,
**kw) **kw)
return fk_defs
def get_indexes(self, table_name, schema=None, **kw): def get_indexes(self, table_name, schema=None, **kw):
"""Return information about indexes in `table_name`. """Return information about indexes in `table_name`.
@ -232,104 +460,261 @@ class Inspector(object):
unique unique
boolean boolean
\**kw dialect_options
other options passed to the dialect's get_indexes() method. dict of dialect-specific index options. May not be present
for all dialects.
.. versionadded:: 1.0.0
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
""" """
indexes = self.dialect.get_indexes(self.conn, table_name, return self.dialect.get_indexes(self.bind, table_name,
schema, schema,
info_cache=self.info_cache, **kw) info_cache=self.info_cache, **kw)
return indexes
def reflecttable(self, table, include_columns): def get_unique_constraints(self, table_name, schema=None, **kw):
"""Return information about unique constraints in `table_name`.
dialect = self.conn.dialect Given a string `table_name` and an optional string `schema`, return
unique constraint information as a list of dicts with these keys:
# MySQL dialect does this. Applicable with other dialects? name
if hasattr(dialect, '_connection_charset') \ the unique constraint's name
and hasattr(dialect, '_adjust_casing'):
charset = dialect._connection_charset
dialect._adjust_casing(table)
# table attributes we might need. column_names
reflection_options = dict( list of column names in order
(k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs)
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
.. versionadded:: 0.8.4
"""
return self.dialect.get_unique_constraints(
self.bind, table_name, schema, info_cache=self.info_cache, **kw)
def get_check_constraints(self, table_name, schema=None, **kw):
"""Return information about check constraints in `table_name`.
Given a string `table_name` and an optional string `schema`, return
check constraint information as a list of dicts with these keys:
name
the check constraint's name
sqltext
the check constraint's SQL expression
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
:param schema: string schema name; if omitted, uses the default schema
of the database connection. For special quoting,
use :class:`.quoted_name`.
.. versionadded:: 1.1.0
"""
return self.dialect.get_check_constraints(
self.bind, table_name, schema, info_cache=self.info_cache, **kw)
def reflecttable(self, table, include_columns, exclude_columns=(),
_extend_on=None):
"""Given a Table object, load its internal constructs based on
introspection.
This is the underlying method used by most dialects to produce
table reflection. Direct usage is like::
from sqlalchemy import create_engine, MetaData, Table
from sqlalchemy.engine import reflection
engine = create_engine('...')
meta = MetaData()
user_table = Table('user', meta)
insp = Inspector.from_engine(engine)
insp.reflecttable(user_table, None)
:param table: a :class:`~sqlalchemy.schema.Table` instance.
:param include_columns: a list of string column names to include
in the reflection process. If ``None``, all columns are reflected.
"""
if _extend_on is not None:
if table in _extend_on:
return
else:
_extend_on.add(table)
dialect = self.bind.dialect
schema = self.bind.schema_for_object(table)
schema = table.schema
table_name = table.name table_name = table.name
# apply table options # get table-level arguments that are specifically
tbl_opts = self.get_table_options(table_name, schema, **table.kwargs) # intended for reflection, e.g. oracle_resolve_synonyms.
# these are unconditionally passed to related Table
# objects
reflection_options = dict(
(k, table.dialect_kwargs.get(k))
for k in dialect.reflection_options
if k in table.dialect_kwargs
)
# reflect table options, like mysql_engine
tbl_opts = self.get_table_options(
table_name, schema, **table.dialect_kwargs)
if tbl_opts: if tbl_opts:
table.kwargs.update(tbl_opts) # add additional kwargs to the Table if the dialect
# returned them
table._validate_dialect_kwargs(tbl_opts)
# table.kwargs will need to be passed to each reflection method. Make if util.py2k:
# sure keywords are strings. if isinstance(schema, str):
tblkw = table.kwargs.copy() schema = schema.decode(dialect.encoding)
for (k, v) in tblkw.items(): if isinstance(table_name, str):
del tblkw[k] table_name = table_name.decode(dialect.encoding)
tblkw[str(k)] = v
# Py2K
if isinstance(schema, str):
schema = schema.decode(dialect.encoding)
if isinstance(table_name, str):
table_name = table_name.decode(dialect.encoding)
# end Py2K
# columns
found_table = False found_table = False
for col_d in self.get_columns(table_name, schema, **tblkw): cols_by_orig_name = {}
found_table = True
name = col_d['name']
if include_columns and name not in include_columns:
continue
coltype = col_d['type'] for col_d in self.get_columns(
col_kw = { table_name, schema, **table.dialect_kwargs):
'nullable':col_d['nullable'], found_table = True
}
if 'autoincrement' in col_d: self._reflect_column(
col_kw['autoincrement'] = col_d['autoincrement'] table, col_d, include_columns,
if 'quote' in col_d: exclude_columns, cols_by_orig_name)
col_kw['quote'] = col_d['quote']
colargs = []
if col_d.get('default') is not None:
# the "default" value is assumed to be a literal SQL expression,
# so is wrapped in text() so that no quoting occurs on re-issuance.
colargs.append(sa_schema.DefaultClause(sql.text(col_d['default'])))
if 'sequence' in col_d:
# TODO: mssql, maxdb and sybase are using this.
seq = col_d['sequence']
sequence = sa_schema.Sequence(seq['name'], 1, 1)
if 'start' in seq:
sequence.start = seq['start']
if 'increment' in seq:
sequence.increment = seq['increment']
colargs.append(sequence)
col = sa_schema.Column(name, coltype, *colargs, **col_kw)
table.append_column(col)
if not found_table: if not found_table:
raise exc.NoSuchTableError(table.name) raise exc.NoSuchTableError(table.name)
# Primary keys self._reflect_pk(
primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[ table_name, schema, table, cols_by_orig_name, exclude_columns)
table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw)
if pk in table.c
])
table.append_constraint(primary_key_constraint) self._reflect_fk(
table_name, schema, table, cols_by_orig_name,
exclude_columns, _extend_on, reflection_options)
# Foreign keys self._reflect_indexes(
fkeys = self.get_foreign_keys(table_name, schema, **tblkw) table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options)
self._reflect_unique_constraints(
table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options)
self._reflect_check_constraints(
table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options)
def _reflect_column(
self, table, col_d, include_columns,
exclude_columns, cols_by_orig_name):
orig_name = col_d['name']
table.dispatch.column_reflect(self, table, col_d)
# fetch name again as column_reflect is allowed to
# change it
name = col_d['name']
if (include_columns and name not in include_columns) \
or (exclude_columns and name in exclude_columns):
return
coltype = col_d['type']
col_kw = dict(
(k, col_d[k])
for k in ['nullable', 'autoincrement', 'quote', 'info', 'key']
if k in col_d
)
colargs = []
if col_d.get('default') is not None:
default = col_d['default']
if isinstance(default, sql.elements.TextClause):
default = sa_schema.DefaultClause(default, _reflected=True)
elif not isinstance(default, sa_schema.FetchedValue):
default = sa_schema.DefaultClause(
sql.text(col_d['default']), _reflected=True)
colargs.append(default)
if 'sequence' in col_d:
self._reflect_col_sequence(col_d, colargs)
cols_by_orig_name[orig_name] = col = \
sa_schema.Column(name, coltype, *colargs, **col_kw)
if col.key in table.primary_key:
col.primary_key = True
table.append_column(col)
def _reflect_col_sequence(self, col_d, colargs):
if 'sequence' in col_d:
# TODO: mssql and sybase are using this.
seq = col_d['sequence']
sequence = sa_schema.Sequence(seq['name'], 1, 1)
if 'start' in seq:
sequence.start = seq['start']
if 'increment' in seq:
sequence.increment = seq['increment']
colargs.append(sequence)
def _reflect_pk(
self, table_name, schema, table,
cols_by_orig_name, exclude_columns):
pk_cons = self.get_pk_constraint(
table_name, schema, **table.dialect_kwargs)
if pk_cons:
pk_cols = [
cols_by_orig_name[pk]
for pk in pk_cons['constrained_columns']
if pk in cols_by_orig_name and pk not in exclude_columns
]
# update pk constraint name
table.primary_key.name = pk_cons.get('name')
# tell the PKConstraint to re-initialize
# its column collection
table.primary_key._reload(pk_cols)
def _reflect_fk(
self, table_name, schema, table, cols_by_orig_name,
exclude_columns, _extend_on, reflection_options):
fkeys = self.get_foreign_keys(
table_name, schema, **table.dialect_kwargs)
for fkey_d in fkeys: for fkey_d in fkeys:
conname = fkey_d['name'] conname = fkey_d['name']
constrained_columns = fkey_d['constrained_columns'] # look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
constrained_columns = [
cols_by_orig_name[c].key
if c in cols_by_orig_name else c
for c in fkey_d['constrained_columns']
]
if exclude_columns and set(constrained_columns).intersection(
exclude_columns):
continue
referred_schema = fkey_d['referred_schema'] referred_schema = fkey_d['referred_schema']
referred_table = fkey_d['referred_table'] referred_table = fkey_d['referred_table']
referred_columns = fkey_d['referred_columns'] referred_columns = fkey_d['referred_columns']
@ -337,7 +722,8 @@ class Inspector(object):
if referred_schema is not None: if referred_schema is not None:
sa_schema.Table(referred_table, table.metadata, sa_schema.Table(referred_table, table.metadata,
autoload=True, schema=referred_schema, autoload=True, schema=referred_schema,
autoload_with=self.conn, autoload_with=self.bind,
_extend_on=_extend_on,
**reflection_options **reflection_options
) )
for column in referred_columns: for column in referred_columns:
@ -345,26 +731,113 @@ class Inspector(object):
[referred_schema, referred_table, column])) [referred_schema, referred_table, column]))
else: else:
sa_schema.Table(referred_table, table.metadata, autoload=True, sa_schema.Table(referred_table, table.metadata, autoload=True,
autoload_with=self.conn, autoload_with=self.bind,
schema=sa_schema.BLANK_SCHEMA,
_extend_on=_extend_on,
**reflection_options **reflection_options
) )
for column in referred_columns: for column in referred_columns:
refspec.append(".".join([referred_table, column])) refspec.append(".".join([referred_table, column]))
if 'options' in fkey_d:
options = fkey_d['options']
else:
options = {}
table.append_constraint( table.append_constraint(
sa_schema.ForeignKeyConstraint(constrained_columns, refspec, sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
conname, link_to_name=True)) conname, link_to_name=True,
# Indexes **options))
def _reflect_indexes(
self, table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options):
# Indexes
indexes = self.get_indexes(table_name, schema) indexes = self.get_indexes(table_name, schema)
for index_d in indexes: for index_d in indexes:
name = index_d['name'] name = index_d['name']
columns = index_d['column_names'] columns = index_d['column_names']
unique = index_d['unique'] unique = index_d['unique']
flavor = index_d.get('type', 'unknown type') flavor = index_d.get('type', 'index')
dialect_options = index_d.get('dialect_options', {})
duplicates = index_d.get('duplicates_constraint')
if include_columns and \ if include_columns and \
not set(columns).issubset(include_columns): not set(columns).issubset(include_columns):
util.warn( util.warn(
"Omitting %s KEY for (%s), key covers omitted columns." % "Omitting %s key for (%s), key covers omitted columns." %
(flavor, ', '.join(columns))) (flavor, ', '.join(columns)))
continue continue
sa_schema.Index(name, *[table.columns[c] for c in columns], if duplicates:
**dict(unique=unique)) continue
# look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
idx_cols = []
for c in columns:
try:
idx_col = cols_by_orig_name[c] \
if c in cols_by_orig_name else table.c[c]
except KeyError:
util.warn(
"%s key '%s' was not located in "
"columns for table '%s'" % (
flavor, c, table_name
))
else:
idx_cols.append(idx_col)
sa_schema.Index(
name, *idx_cols,
**dict(list(dialect_options.items()) + [('unique', unique)])
)
def _reflect_unique_constraints(
self, table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options):
# Unique Constraints
try:
constraints = self.get_unique_constraints(table_name, schema)
except NotImplementedError:
# optional dialect feature
return
for const_d in constraints:
conname = const_d['name']
columns = const_d['column_names']
duplicates = const_d.get('duplicates_index')
if include_columns and \
not set(columns).issubset(include_columns):
util.warn(
"Omitting unique constraint key for (%s), "
"key covers omitted columns." %
', '.join(columns))
continue
if duplicates:
continue
# look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
constrained_cols = []
for c in columns:
try:
constrained_col = cols_by_orig_name[c] \
if c in cols_by_orig_name else table.c[c]
except KeyError:
util.warn(
"unique constraint key '%s' was not located in "
"columns for table '%s'" % (c, table_name))
else:
constrained_cols.append(constrained_col)
table.append_constraint(
sa_schema.UniqueConstraint(*constrained_cols, name=conname))
def _reflect_check_constraints(
self, table_name, schema, table, cols_by_orig_name,
include_columns, exclude_columns, reflection_options):
try:
constraints = self.get_check_constraints(table_name, schema)
except NotImplementedError:
# optional dialect feature
return
for const_d in constraints:
table.append_constraint(
sa_schema.CheckConstraint(**const_d))

View File

@ -1,3 +1,10 @@
# engine/strategies.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Strategies for creating new instances of Engine types. """Strategies for creating new instances of Engine types.
These are semi-private implementation classes which provide the These are semi-private implementation classes which provide the
@ -11,18 +18,19 @@ New strategies can be added via new ``EngineStrategy`` classes.
from operator import attrgetter from operator import attrgetter
from sqlalchemy.engine import base, threadlocal, url from sqlalchemy.engine import base, threadlocal, url
from sqlalchemy import util, exc from sqlalchemy import util, event
from sqlalchemy import pool as poollib from sqlalchemy import pool as poollib
from sqlalchemy.sql import schema
strategies = {} strategies = {}
class EngineStrategy(object): class EngineStrategy(object):
"""An adaptor that processes input arguements and produces an Engine. """An adaptor that processes input arguments and produces an Engine.
Provides a ``create`` method that receives input arguments and Provides a ``create`` method that receives input arguments and
produces an instance of base.Engine or a subclass. produces an instance of base.Engine or a subclass.
""" """
def __init__(self): def __init__(self):
@ -35,58 +43,75 @@ class EngineStrategy(object):
class DefaultEngineStrategy(EngineStrategy): class DefaultEngineStrategy(EngineStrategy):
"""Base class for built-in stratgies.""" """Base class for built-in strategies."""
pool_threadlocal = False
def create(self, name_or_url, **kwargs): def create(self, name_or_url, **kwargs):
# create url.URL object # create url.URL object
u = url.make_url(name_or_url) u = url.make_url(name_or_url)
dialect_cls = u.get_dialect() plugins = u._instantiate_plugins(kwargs)
u.query.pop('plugin', None)
entrypoint = u._get_entrypoint()
dialect_cls = entrypoint.get_dialect_cls(u)
if kwargs.pop('_coerce_config', False):
def pop_kwarg(key, default=None):
value = kwargs.pop(key, default)
if key in dialect_cls.engine_config_types:
value = dialect_cls.engine_config_types[key](value)
return value
else:
pop_kwarg = kwargs.pop
dialect_args = {} dialect_args = {}
# consume dialect arguments from kwargs # consume dialect arguments from kwargs
for k in util.get_cls_kwargs(dialect_cls): for k in util.get_cls_kwargs(dialect_cls):
if k in kwargs: if k in kwargs:
dialect_args[k] = kwargs.pop(k) dialect_args[k] = pop_kwarg(k)
dbapi = kwargs.pop('module', None) dbapi = kwargs.pop('module', None)
if dbapi is None: if dbapi is None:
dbapi_args = {} dbapi_args = {}
for k in util.get_func_kwargs(dialect_cls.dbapi): for k in util.get_func_kwargs(dialect_cls.dbapi):
if k in kwargs: if k in kwargs:
dbapi_args[k] = kwargs.pop(k) dbapi_args[k] = pop_kwarg(k)
dbapi = dialect_cls.dbapi(**dbapi_args) dbapi = dialect_cls.dbapi(**dbapi_args)
dialect_args['dbapi'] = dbapi dialect_args['dbapi'] = dbapi
for plugin in plugins:
plugin.handle_dialect_kwargs(dialect_cls, dialect_args)
# create dialect # create dialect
dialect = dialect_cls(**dialect_args) dialect = dialect_cls(**dialect_args)
# assemble connection arguments # assemble connection arguments
(cargs, cparams) = dialect.create_connect_args(u) (cargs, cparams) = dialect.create_connect_args(u)
cparams.update(kwargs.pop('connect_args', {})) cparams.update(pop_kwarg('connect_args', {}))
cargs = list(cargs) # allow mutability
# look for existing pool or create # look for existing pool or create
pool = kwargs.pop('pool', None) pool = pop_kwarg('pool', None)
if pool is None: if pool is None:
def connect(): def connect(connection_record=None):
try: if dialect._has_events:
return dialect.connect(*cargs, **cparams) for fn in dialect.dispatch.do_connect:
except Exception, e: connection = fn(
# Py3K dialect, connection_record, cargs, cparams)
#raise exc.DBAPIError.instance(None, None, e) from e if connection is not None:
# Py2K return connection
import sys return dialect.connect(*cargs, **cparams)
raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2]
# end Py2K
creator = kwargs.pop('creator', connect)
poolclass = (kwargs.pop('poolclass', None) or creator = pop_kwarg('creator', connect)
getattr(dialect_cls, 'poolclass', poollib.QueuePool))
pool_args = {} poolclass = pop_kwarg('poolclass', None)
if poolclass is None:
poolclass = dialect_cls.get_pool_class(u)
pool_args = {
'dialect': dialect
}
# consume pool arguments from kwargs, translating a few of # consume pool arguments from kwargs, translating a few of
# the arguments # the arguments
@ -94,12 +119,17 @@ class DefaultEngineStrategy(EngineStrategy):
'echo': 'echo_pool', 'echo': 'echo_pool',
'timeout': 'pool_timeout', 'timeout': 'pool_timeout',
'recycle': 'pool_recycle', 'recycle': 'pool_recycle',
'use_threadlocal':'pool_threadlocal'} 'events': 'pool_events',
'use_threadlocal': 'pool_threadlocal',
'reset_on_return': 'pool_reset_on_return'}
for k in util.get_cls_kwargs(poolclass): for k in util.get_cls_kwargs(poolclass):
tk = translate.get(k, k) tk = translate.get(k, k)
if tk in kwargs: if tk in kwargs:
pool_args[k] = kwargs.pop(tk) pool_args[k] = pop_kwarg(tk)
pool_args.setdefault('use_threadlocal', self.pool_threadlocal)
for plugin in plugins:
plugin.handle_pool_kwargs(poolclass, pool_args)
pool = poolclass(creator, **pool_args) pool = poolclass(creator, **pool_args)
else: else:
if isinstance(pool, poollib._DBProxy): if isinstance(pool, poollib._DBProxy):
@ -107,15 +137,17 @@ class DefaultEngineStrategy(EngineStrategy):
else: else:
pool = pool pool = pool
pool._dialect = dialect
# create engine. # create engine.
engineclass = self.engine_cls engineclass = self.engine_cls
engine_args = {} engine_args = {}
for k in util.get_cls_kwargs(engineclass): for k in util.get_cls_kwargs(engineclass):
if k in kwargs: if k in kwargs:
engine_args[k] = kwargs.pop(k) engine_args[k] = pop_kwarg(k)
_initialize = kwargs.pop('_initialize', True) _initialize = kwargs.pop('_initialize', True)
# all kwargs should be consumed # all kwargs should be consumed
if kwargs: if kwargs:
raise TypeError( raise TypeError(
@ -126,24 +158,35 @@ class DefaultEngineStrategy(EngineStrategy):
dialect.__class__.__name__, dialect.__class__.__name__,
pool.__class__.__name__, pool.__class__.__name__,
engineclass.__name__)) engineclass.__name__))
engine = engineclass(pool, dialect, u, **engine_args) engine = engineclass(pool, dialect, u, **engine_args)
if _initialize: if _initialize:
do_on_connect = dialect.on_connect() do_on_connect = dialect.on_connect()
if do_on_connect: if do_on_connect:
def on_connect(conn, rec): def on_connect(dbapi_connection, connection_record):
conn = getattr(conn, '_sqla_unwrap', conn) conn = getattr(
dbapi_connection, '_sqla_unwrap', dbapi_connection)
if conn is None: if conn is None:
return return
do_on_connect(conn) do_on_connect(conn)
pool.add_listener({'first_connect': on_connect, 'connect':on_connect}) event.listen(pool, 'first_connect', on_connect)
event.listen(pool, 'connect', on_connect)
def first_connect(conn, rec):
c = base.Connection(engine, connection=conn) def first_connect(dbapi_connection, connection_record):
c = base.Connection(engine, connection=dbapi_connection,
_has_events=False)
c._execution_options = util.immutabledict()
dialect.initialize(c) dialect.initialize(c)
pool.add_listener({'first_connect':first_connect}) event.listen(pool, 'first_connect', first_connect, once=True)
dialect_cls.engine_created(engine)
if entrypoint is not dialect_cls:
entrypoint.engine_created(engine)
for plugin in plugins:
plugin.engine_created(engine)
return engine return engine
@ -153,15 +196,14 @@ class PlainEngineStrategy(DefaultEngineStrategy):
name = 'plain' name = 'plain'
engine_cls = base.Engine engine_cls = base.Engine
PlainEngineStrategy() PlainEngineStrategy()
class ThreadLocalEngineStrategy(DefaultEngineStrategy): class ThreadLocalEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring an Engine with thredlocal behavior.""" """Strategy for configuring an Engine with threadlocal behavior."""
name = 'threadlocal' name = 'threadlocal'
pool_threadlocal = True
engine_cls = threadlocal.TLEngine engine_cls = threadlocal.TLEngine
ThreadLocalEngineStrategy() ThreadLocalEngineStrategy()
@ -172,11 +214,11 @@ class MockEngineStrategy(EngineStrategy):
Produces a single mock Connectable object which dispatches Produces a single mock Connectable object which dispatches
statement execution to a passed-in function. statement execution to a passed-in function.
""" """
name = 'mock' name = 'mock'
def create(self, name_or_url, executor, **kwargs): def create(self, name_or_url, executor, **kwargs):
# create url.URL object # create url.URL object
u = url.make_url(name_or_url) u = url.make_url(name_or_url)
@ -203,9 +245,14 @@ class MockEngineStrategy(EngineStrategy):
dialect = property(attrgetter('_dialect')) dialect = property(attrgetter('_dialect'))
name = property(lambda s: s._dialect.name) name = property(lambda s: s._dialect.name)
schema_for_object = schema._schema_getter(None)
def contextual_connect(self, **kwargs): def contextual_connect(self, **kwargs):
return self return self
def execution_options(self, **kw):
return self
def compiler(self, statement, parameters, **kwargs): def compiler(self, statement, parameters, **kwargs):
return self._dialect.compiler( return self._dialect.compiler(
statement, parameters, engine=self, **kwargs) statement, parameters, engine=self, **kwargs)
@ -213,13 +260,22 @@ class MockEngineStrategy(EngineStrategy):
def create(self, entity, **kwargs): def create(self, entity, **kwargs):
kwargs['checkfirst'] = False kwargs['checkfirst'] = False
from sqlalchemy.engine import ddl from sqlalchemy.engine import ddl
ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity) ddl.SchemaGenerator(
self.dialect, self, **kwargs).traverse_single(entity)
def drop(self, entity, **kwargs): def drop(self, entity, **kwargs):
kwargs['checkfirst'] = False kwargs['checkfirst'] = False
from sqlalchemy.engine import ddl from sqlalchemy.engine import ddl
ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity) ddl.SchemaDropper(
self.dialect, self, **kwargs).traverse_single(entity)
def _run_visitor(self, visitorcallable, element,
connection=None,
**kwargs):
kwargs['checkfirst'] = False
visitorcallable(self.dialect, self,
**kwargs).traverse_single(element)
def execute(self, object, *multiparams, **params): def execute(self, object, *multiparams, **params):
raise NotImplementedError() raise NotImplementedError()

View File

@ -1,23 +1,33 @@
# engine/threadlocal.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Provides a thread-local transactional wrapper around the root Engine class. """Provides a thread-local transactional wrapper around the root Engine class.
The ``threadlocal`` module is invoked when using the ``strategy="threadlocal"`` flag The ``threadlocal`` module is invoked when using the
with :func:`~sqlalchemy.engine.create_engine`. This module is semi-private and is ``strategy="threadlocal"`` flag with :func:`~sqlalchemy.engine.create_engine`.
invoked automatically when the threadlocal engine strategy is used. This module is semi-private and is invoked automatically when the threadlocal
engine strategy is used.
""" """
from sqlalchemy import util from .. import util
from sqlalchemy.engine import base from . import base
import weakref import weakref
class TLConnection(base.Connection): class TLConnection(base.Connection):
def __init__(self, *arg, **kw): def __init__(self, *arg, **kw):
super(TLConnection, self).__init__(*arg, **kw) super(TLConnection, self).__init__(*arg, **kw)
self.__opencount = 0 self.__opencount = 0
def _increment_connect(self): def _increment_connect(self):
self.__opencount += 1 self.__opencount += 1
return self return self
def close(self): def close(self):
if self.__opencount == 1: if self.__opencount == 1:
base.Connection.close(self) base.Connection.close(self)
@ -27,70 +37,95 @@ class TLConnection(base.Connection):
self.__opencount = 0 self.__opencount = 0
base.Connection.close(self) base.Connection.close(self)
class TLEngine(base.Engine):
"""An Engine that includes support for thread-local managed transactions."""
class TLEngine(base.Engine):
"""An Engine that includes support for thread-local managed
transactions.
"""
_tl_connection_cls = TLConnection
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TLEngine, self).__init__(*args, **kwargs) super(TLEngine, self).__init__(*args, **kwargs)
self._connections = util.threading.local() self._connections = util.threading.local()
proxy = kwargs.get('proxy')
if proxy:
self.TLConnection = base._proxy_connection_cls(TLConnection, proxy)
else:
self.TLConnection = TLConnection
def contextual_connect(self, **kw): def contextual_connect(self, **kw):
if not hasattr(self._connections, 'conn'): if not hasattr(self._connections, 'conn'):
connection = None connection = None
else: else:
connection = self._connections.conn() connection = self._connections.conn()
if connection is None or connection.closed: if connection is None or connection.closed:
# guards against pool-level reapers, if desired. # guards against pool-level reapers, if desired.
# or not connection.connection.is_valid: # or not connection.connection.is_valid:
connection = self.TLConnection(self, self.pool.connect(), **kw) connection = self._tl_connection_cls(
self._connections.conn = conn = weakref.ref(connection) self,
self._wrap_pool_connect(
self.pool.connect, connection),
**kw)
self._connections.conn = weakref.ref(connection)
return connection._increment_connect() return connection._increment_connect()
def begin_twophase(self, xid=None): def begin_twophase(self, xid=None):
if not hasattr(self._connections, 'trans'): if not hasattr(self._connections, 'trans'):
self._connections.trans = [] self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid)) self._connections.trans.append(
self.contextual_connect().begin_twophase(xid=xid))
return self
def begin_nested(self): def begin_nested(self):
if not hasattr(self._connections, 'trans'): if not hasattr(self._connections, 'trans'):
self._connections.trans = [] self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin_nested()) self._connections.trans.append(
self.contextual_connect().begin_nested())
return self
def begin(self): def begin(self):
if not hasattr(self._connections, 'trans'): if not hasattr(self._connections, 'trans'):
self._connections.trans = [] self._connections.trans = []
self._connections.trans.append(self.contextual_connect().begin()) self._connections.trans.append(self.contextual_connect().begin())
return self
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
if type is None:
self.commit()
else:
self.rollback()
def prepare(self): def prepare(self):
if not hasattr(self._connections, 'trans') or \
not self._connections.trans:
return
self._connections.trans[-1].prepare() self._connections.trans[-1].prepare()
def commit(self): def commit(self):
if not hasattr(self._connections, 'trans') or \
not self._connections.trans:
return
trans = self._connections.trans.pop(-1) trans = self._connections.trans.pop(-1)
trans.commit() trans.commit()
def rollback(self): def rollback(self):
if not hasattr(self._connections, 'trans') or \
not self._connections.trans:
return
trans = self._connections.trans.pop(-1) trans = self._connections.trans.pop(-1)
trans.rollback() trans.rollback()
def dispose(self): def dispose(self):
self._connections = util.threading.local() self._connections = util.threading.local()
super(TLEngine, self).dispose() super(TLEngine, self).dispose()
@property @property
def closed(self): def closed(self):
return not hasattr(self._connections, 'conn') or \ return not hasattr(self._connections, 'conn') or \
self._connections.conn() is None or \ self._connections.conn() is None or \
self._connections.conn().closed self._connections.conn().closed
def close(self): def close(self):
if not self.closed: if not self.closed:
self.contextual_connect().close() self.contextual_connect().close()
@ -98,6 +133,6 @@ class TLEngine(base.Engine):
connection._force_close() connection._force_close()
del self._connections.conn del self._connections.conn
self._connections.trans = [] self._connections.trans = []
def __repr__(self): def __repr__(self):
return 'TLEngine(%s)' % str(self.url) return 'TLEngine(%r)' % self.url

View File

@ -1,13 +1,23 @@
# engine/url.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates """Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates
information about a database connection specification. information about a database connection specification.
The URL object is created automatically when :func:`~sqlalchemy.engine.create_engine` is called The URL object is created automatically when
with a string argument; alternatively, the URL is a public-facing construct which can :func:`~sqlalchemy.engine.create_engine` is called with a string
argument; alternatively, the URL is a public-facing construct which can
be used directly and is also accepted directly by ``create_engine()``. be used directly and is also accepted directly by ``create_engine()``.
""" """
import re, cgi, sys, urllib import re
from sqlalchemy import exc from .. import exc, util
from . import Dialect
from ..dialects import registry, plugins
class URL(object): class URL(object):
@ -15,8 +25,8 @@ class URL(object):
Represent the components of a URL used to connect to a database. Represent the components of a URL used to connect to a database.
This object is suitable to be passed directly to a This object is suitable to be passed directly to a
``create_engine()`` call. The fields of the URL are parsed from a :func:`~sqlalchemy.create_engine` call. The fields of the URL are parsed
string by the ``module-level make_url()`` function. the string from a string by the :func:`.make_url` function. the string
format of the URL is an RFC-1738-style string. format of the URL is an RFC-1738-style string.
All initialization parameters are available as public attributes. All initialization parameters are available as public attributes.
@ -53,25 +63,35 @@ class URL(object):
self.database = database self.database = database
self.query = query or {} self.query = query or {}
def __str__(self): def __to_string__(self, hide_password=True):
s = self.drivername + "://" s = self.drivername + "://"
if self.username is not None: if self.username is not None:
s += self.username s += _rfc_1738_quote(self.username)
if self.password is not None: if self.password is not None:
s += ':' + urllib.quote_plus(self.password) s += ':' + ('***' if hide_password
else _rfc_1738_quote(self.password))
s += "@" s += "@"
if self.host is not None: if self.host is not None:
s += self.host if ':' in self.host:
s += "[%s]" % self.host
else:
s += self.host
if self.port is not None: if self.port is not None:
s += ':' + str(self.port) s += ':' + str(self.port)
if self.database is not None: if self.database is not None:
s += '/' + self.database s += '/' + self.database
if self.query: if self.query:
keys = self.query.keys() keys = list(self.query)
keys.sort() keys.sort()
s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys) s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
return s return s
def __str__(self):
return self.__to_string__(hide_password=False)
def __repr__(self):
return self.__to_string__()
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(str(self))
@ -85,49 +105,58 @@ class URL(object):
self.database == other.database and \ self.database == other.database and \
self.query == other.query self.query == other.query
def get_backend_name(self):
if '+' not in self.drivername:
return self.drivername
else:
return self.drivername.split('+')[0]
def get_driver_name(self):
if '+' not in self.drivername:
return self.get_dialect().driver
else:
return self.drivername.split('+')[1]
def _instantiate_plugins(self, kwargs):
plugin_names = util.to_list(self.query.get('plugin', ()))
return [
plugins.load(plugin_name)(self, kwargs)
for plugin_name in plugin_names
]
def _get_entrypoint(self):
"""Return the "entry point" dialect class.
This is normally the dialect itself except in the case when the
returned class implements the get_dialect_cls() method.
"""
if '+' not in self.drivername:
name = self.drivername
else:
name = self.drivername.replace('+', '.')
cls = registry.load(name)
# check for legacy dialects that
# would return a module with 'dialect' as the
# actual class
if hasattr(cls, 'dialect') and \
isinstance(cls.dialect, type) and \
issubclass(cls.dialect, Dialect):
return cls.dialect
else:
return cls
def get_dialect(self): def get_dialect(self):
"""Return the SQLAlchemy database dialect class corresponding """Return the SQLAlchemy database dialect class corresponding
to this URL's driver name. to this URL's driver name.
""" """
entrypoint = self._get_entrypoint()
dialect_cls = entrypoint.get_dialect_cls(self)
return dialect_cls
try:
if '+' in self.drivername:
dialect, driver = self.drivername.split('+')
else:
dialect, driver = self.drivername, 'base'
module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
module = getattr(module, dialect)
module = getattr(module, driver)
return module.dialect
except ImportError:
module = self._load_entry_point()
if module is not None:
return module
else:
raise
def _load_entry_point(self):
"""attempt to load this url's dialect from entry points, or return None
if pkg_resources is not installed or there is no matching entry point.
Raise ImportError if the actual load fails.
"""
try:
import pkg_resources
except ImportError:
return None
for res in pkg_resources.iter_entry_points('sqlalchemy.dialects'):
if res.name == self.drivername:
return res.load()
else:
return None
def translate_connect_args(self, names=[], **kw): def translate_connect_args(self, names=[], **kw):
"""Translate url attributes into a dictionary of connection arguments. r"""Translate url attributes into a dictionary of connection arguments.
Returns attributes of this url (`host`, `database`, `username`, Returns attributes of this url (`host`, `database`, `username`,
`password`, `port`) as a plain dictionary. The attribute names are `password`, `port`) as a plain dictionary. The attribute names are
@ -136,8 +165,8 @@ class URL(object):
:param \**kw: Optional, alternate key names for url attributes. :param \**kw: Optional, alternate key names for url attributes.
:param names: Deprecated. Same purpose as the keyword-based alternate names, :param names: Deprecated. Same purpose as the keyword-based alternate
but correlates the name to the original positionally. names, but correlates the name to the original positionally.
""" """
translated = {} translated = {}
@ -153,6 +182,7 @@ class URL(object):
translated[name] = getattr(self, sname) translated[name] = getattr(self, sname)
return translated return translated
def make_url(name_or_url): def make_url(name_or_url):
"""Given a string or unicode instance, produce a new URL instance. """Given a string or unicode instance, produce a new URL instance.
@ -160,25 +190,28 @@ def make_url(name_or_url):
existing URL object is passed, just returns the object. existing URL object is passed, just returns the object.
""" """
if isinstance(name_or_url, basestring): if isinstance(name_or_url, util.string_types):
return _parse_rfc1738_args(name_or_url) return _parse_rfc1738_args(name_or_url)
else: else:
return name_or_url return name_or_url
def _parse_rfc1738_args(name): def _parse_rfc1738_args(name):
pattern = re.compile(r''' pattern = re.compile(r'''
(?P<name>[\w\+]+):// (?P<name>[\w\+]+)://
(?: (?:
(?P<username>[^:/]*) (?P<username>[^:/]*)
(?::(?P<password>[^/]*))? (?::(?P<password>.*))?
@)? @)?
(?: (?:
(?P<host>[^/:]*) (?:
\[(?P<ipv6host>[^/]+)\] |
(?P<ipv4host>[^/:]+)
)?
(?::(?P<port>[^/]*))? (?::(?P<port>[^/]*))?
)? )?
(?:/(?P<database>.*))? (?:/(?P<database>.*))?
''' ''', re.X)
, re.X)
m = pattern.match(name) m = pattern.match(name)
if m is not None: if m is not None:
@ -186,29 +219,43 @@ def _parse_rfc1738_args(name):
if components['database'] is not None: if components['database'] is not None:
tokens = components['database'].split('?', 2) tokens = components['database'].split('?', 2)
components['database'] = tokens[0] components['database'] = tokens[0]
query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None query = (
# Py2K len(tokens) > 1 and dict(util.parse_qsl(tokens[1]))) or None
if query is not None: if util.py2k and query is not None:
query = dict((k.encode('ascii'), query[k]) for k in query) query = dict((k.encode('ascii'), query[k]) for k in query)
# end Py2K
else: else:
query = None query = None
components['query'] = query components['query'] = query
if components['password'] is not None: if components['username'] is not None:
components['password'] = urllib.unquote_plus(components['password']) components['username'] = _rfc_1738_unquote(components['username'])
if components['password'] is not None:
components['password'] = _rfc_1738_unquote(components['password'])
ipv4host = components.pop('ipv4host')
ipv6host = components.pop('ipv6host')
components['host'] = ipv4host or ipv6host
name = components.pop('name') name = components.pop('name')
return URL(name, **components) return URL(name, **components)
else: else:
raise exc.ArgumentError( raise exc.ArgumentError(
"Could not parse rfc1738 URL from string '%s'" % name) "Could not parse rfc1738 URL from string '%s'" % name)
def _rfc_1738_quote(text):
return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text)
def _rfc_1738_unquote(text):
return util.unquote(text)
def _parse_keyvalue_args(name): def _parse_keyvalue_args(name):
m = re.match( r'(\w+)://(.*)', name) m = re.match(r'(\w+)://(.*)', name)
if m is not None: if m is not None:
(name, args) = m.group(1, 2) (name, args) = m.group(1, 2)
opts = dict( cgi.parse_qsl( args ) ) opts = dict(util.parse_qsl(args))
return URL(name, *opts) return URL(name, *opts)
else: else:
return None return None

View File

@ -1,13 +1,15 @@
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com # sqlalchemy/exc.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Exceptions used with SQLAlchemy. """Exceptions used with SQLAlchemy.
The base exception class is SQLAlchemyError. Exceptions which are raised as a The base exception class is :exc:`.SQLAlchemyError`. Exceptions which are
result of DBAPI exceptions are all subclasses of raised as a result of DBAPI exceptions are all subclasses of
:class:`~sqlalchemy.exc.DBAPIError`. :exc:`.DBAPIError`.
""" """
@ -24,31 +26,100 @@ class ArgumentError(SQLAlchemyError):
""" """
class ObjectNotExecutableError(ArgumentError):
"""Raised when an object is passed to .execute() that can't be
executed as SQL.
.. versionadded:: 1.1
"""
def __init__(self, target):
super(ObjectNotExecutableError, self).__init__(
"Not an executable object: %r" % target
)
class NoSuchModuleError(ArgumentError):
"""Raised when a dynamically-loaded module (usually a database dialect)
of a particular name cannot be located."""
class NoForeignKeysError(ArgumentError):
"""Raised when no foreign keys can be located between two selectables
during a join."""
class AmbiguousForeignKeysError(ArgumentError):
"""Raised when more than one foreign key matching can be located
between two selectables during a join."""
class CircularDependencyError(SQLAlchemyError): class CircularDependencyError(SQLAlchemyError):
"""Raised by topological sorts when a circular dependency is detected""" """Raised by topological sorts when a circular dependency is detected.
There are two scenarios where this error occurs:
* In a Session flush operation, if two objects are mutually dependent
on each other, they can not be inserted or deleted via INSERT or
DELETE statements alone; an UPDATE will be needed to post-associate
or pre-deassociate one of the foreign key constrained values.
The ``post_update`` flag described at :ref:`post_update` can resolve
this cycle.
* In a :attr:`.MetaData.sorted_tables` operation, two :class:`.ForeignKey`
or :class:`.ForeignKeyConstraint` objects mutually refer to each
other. Apply the ``use_alter=True`` flag to one or both,
see :ref:`use_alter`.
"""
def __init__(self, message, cycles, edges, msg=None):
if msg is None:
message += " (%s)" % ", ".join(repr(s) for s in cycles)
else:
message = msg
SQLAlchemyError.__init__(self, message)
self.cycles = cycles
self.edges = edges
def __reduce__(self):
return self.__class__, (None, self.cycles,
self.edges, self.args[0])
class CompileError(SQLAlchemyError): class CompileError(SQLAlchemyError):
"""Raised when an error occurs during SQL compilation""" """Raised when an error occurs during SQL compilation"""
class UnsupportedCompilationError(CompileError):
"""Raised when an operation is not supported by the given compiler.
.. versionadded:: 0.8.3
"""
def __init__(self, compiler, element_type):
super(UnsupportedCompilationError, self).__init__(
"Compiler %r can't render element of type %s" %
(compiler, element_type))
class IdentifierError(SQLAlchemyError): class IdentifierError(SQLAlchemyError):
"""Raised when a schema name is beyond the max character limit""" """Raised when a schema name is beyond the max character limit"""
# Moved to orm.exc; compatability definition installed by orm import until 0.6
ConcurrentModificationError = None
class DisconnectionError(SQLAlchemyError): class DisconnectionError(SQLAlchemyError):
"""A disconnect is detected on a raw DB-API connection. """A disconnect is detected on a raw DB-API connection.
This error is raised and consumed internally by a connection pool. It can This error is raised and consumed internally by a connection pool. It can
be raised by a ``PoolListener`` so that the host pool forces a disconnect. be raised by the :meth:`.PoolEvents.checkout` event so that the host pool
forces a retry; the exception will be caught three times in a row before
the pool gives up and raises :class:`~sqlalchemy.exc.InvalidRequestError`
regarding the connection attempt.
""" """
# Moved to orm.exc; compatability definition installed by orm import until 0.6
FlushError = None
class TimeoutError(SQLAlchemyError): class TimeoutError(SQLAlchemyError):
"""Raised when a connection pool times out on getting a connection.""" """Raised when a connection pool times out on getting a connection."""
@ -60,17 +131,52 @@ class InvalidRequestError(SQLAlchemyError):
""" """
class NoInspectionAvailable(InvalidRequestError):
"""A subject passed to :func:`sqlalchemy.inspection.inspect` produced
no context for inspection."""
class ResourceClosedError(InvalidRequestError):
"""An operation was requested from a connection, cursor, or other
object that's in a closed state."""
class NoSuchColumnError(KeyError, InvalidRequestError): class NoSuchColumnError(KeyError, InvalidRequestError):
"""A nonexistent column is requested from a ``RowProxy``.""" """A nonexistent column is requested from a ``RowProxy``."""
class NoReferenceError(InvalidRequestError): class NoReferenceError(InvalidRequestError):
"""Raised by ``ForeignKey`` to indicate a reference cannot be resolved.""" """Raised by ``ForeignKey`` to indicate a reference cannot be resolved."""
class NoReferencedTableError(NoReferenceError): class NoReferencedTableError(NoReferenceError):
"""Raised by ``ForeignKey`` when the referred ``Table`` cannot be located.""" """Raised by ``ForeignKey`` when the referred ``Table`` cannot be
located.
"""
def __init__(self, message, tname):
NoReferenceError.__init__(self, message)
self.table_name = tname
def __reduce__(self):
return self.__class__, (self.args[0], self.table_name)
class NoReferencedColumnError(NoReferenceError): class NoReferencedColumnError(NoReferenceError):
"""Raised by ``ForeignKey`` when the referred ``Column`` cannot be located.""" """Raised by ``ForeignKey`` when the referred ``Column`` cannot be
located.
"""
def __init__(self, message, tname, cname):
NoReferenceError.__init__(self, message)
self.table_name = tname
self.column_name = cname
def __reduce__(self):
return self.__class__, (self.args[0], self.table_name,
self.column_name)
class NoSuchTableError(InvalidRequestError): class NoSuchTableError(InvalidRequestError):
"""Table does not exist or is not visible to a connection.""" """Table does not exist or is not visible to a connection."""
@ -80,70 +186,161 @@ class UnboundExecutionError(InvalidRequestError):
"""SQL was attempted without a database connection to execute it on.""" """SQL was attempted without a database connection to execute it on."""
# Moved to orm.exc; compatability definition installed by orm import until 0.6 class DontWrapMixin(object):
"""A mixin class which, when applied to a user-defined Exception class,
will not be wrapped inside of :exc:`.StatementError` if the error is
emitted within the process of executing a statement.
E.g.::
from sqlalchemy.exc import DontWrapMixin
class MyCustomException(Exception, DontWrapMixin):
pass
class MySpecialType(TypeDecorator):
impl = String
def process_bind_param(self, value, dialect):
if value == 'invalid':
raise MyCustomException("invalid!")
"""
# Moved to orm.exc; compatibility definition installed by orm import until 0.6
UnmappedColumnError = None UnmappedColumnError = None
class DBAPIError(SQLAlchemyError):
class StatementError(SQLAlchemyError):
"""An error occurred during execution of a SQL statement.
:class:`StatementError` wraps the exception raised
during execution, and features :attr:`.statement`
and :attr:`.params` attributes which supply context regarding
the specifics of the statement which had an issue.
The wrapped exception object is available in
the :attr:`.orig` attribute.
"""
statement = None
"""The string SQL statement being invoked when this exception occurred."""
params = None
"""The parameter list being used when this exception occurred."""
orig = None
"""The DBAPI exception object."""
def __init__(self, message, statement, params, orig):
SQLAlchemyError.__init__(self, message)
self.statement = statement
self.params = params
self.orig = orig
self.detail = []
def add_detail(self, msg):
self.detail.append(msg)
def __reduce__(self):
return self.__class__, (self.args[0], self.statement,
self.params, self.orig)
def __str__(self):
from sqlalchemy.sql import util
details = [SQLAlchemyError.__str__(self)]
if self.statement:
details.append("[SQL: %r]" % self.statement)
if self.params:
params_repr = util._repr_params(self.params, 10)
details.append("[parameters: %r]" % params_repr)
return ' '.join([
"(%s)" % det for det in self.detail
] + details)
def __unicode__(self):
return self.__str__()
class DBAPIError(StatementError):
"""Raised when the execution of a database operation fails. """Raised when the execution of a database operation fails.
``DBAPIError`` wraps exceptions raised by the DB-API underlying the Wraps exceptions raised by the DB-API underlying the
database operation. Driver-specific implementations of the standard database operation. Driver-specific implementations of the standard
DB-API exception types are wrapped by matching sub-types of SQLAlchemy's DB-API exception types are wrapped by matching sub-types of SQLAlchemy's
``DBAPIError`` when possible. DB-API's ``Error`` type maps to :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to
``DBAPIError`` in SQLAlchemy, otherwise the names are identical. Note :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note
that there is no guarantee that different DB-API implementations will that there is no guarantee that different DB-API implementations will
raise the same exception type for any given error condition. raise the same exception type for any given error condition.
If the error-raising operation occured in the execution of a SQL :class:`DBAPIError` features :attr:`~.StatementError.statement`
statement, that statement and its parameters will be available on and :attr:`~.StatementError.params` attributes which supply context
the exception object in the ``statement`` and ``params`` attributes. regarding the specifics of the statement which had an issue, for the
typical case when the error was raised within the context of
emitting a SQL statement.
The wrapped exception object is available in the ``orig`` attribute. The wrapped exception object is available in the
Its type and properties are DB-API implementation specific. :attr:`~.StatementError.orig` attribute. Its type and properties are
DB-API implementation specific.
""" """
@classmethod @classmethod
def instance(cls, statement, params, orig, connection_invalidated=False): def instance(cls, statement, params,
orig, dbapi_base_err,
connection_invalidated=False,
dialect=None):
# Don't ever wrap these, just return them directly as if # Don't ever wrap these, just return them directly as if
# DBAPIError didn't exist. # DBAPIError didn't exist.
if isinstance(orig, (KeyboardInterrupt, SystemExit)): if (isinstance(orig, BaseException) and
not isinstance(orig, Exception)) or \
isinstance(orig, DontWrapMixin):
return orig return orig
if orig is not None: if orig is not None:
name, glob = orig.__class__.__name__, globals() # not a DBAPI error, statement is present.
if name in glob and issubclass(glob[name], DBAPIError): # raise a StatementError
cls = glob[name] if not isinstance(orig, dbapi_base_err) and statement:
return StatementError(
"(%s.%s) %s" %
(orig.__class__.__module__, orig.__class__.__name__,
orig),
statement, params, orig
)
glob = globals()
for super_ in orig.__class__.__mro__:
name = super_.__name__
if dialect:
name = dialect.dbapi_exception_translation_map.get(
name, name)
if name in glob and issubclass(glob[name], DBAPIError):
cls = glob[name]
break
return cls(statement, params, orig, connection_invalidated) return cls(statement, params, orig, connection_invalidated)
def __reduce__(self):
return self.__class__, (self.statement, self.params,
self.orig, self.connection_invalidated)
def __init__(self, statement, params, orig, connection_invalidated=False): def __init__(self, statement, params, orig, connection_invalidated=False):
try: try:
text = str(orig) text = str(orig)
except (KeyboardInterrupt, SystemExit): except Exception as e:
raise
except Exception, e:
text = 'Error in str() of DB-API-generated exception: ' + str(e) text = 'Error in str() of DB-API-generated exception: ' + str(e)
SQLAlchemyError.__init__( StatementError.__init__(
self, '(%s) %s' % (orig.__class__.__name__, text)) self,
self.statement = statement '(%s.%s) %s' % (
self.params = params orig.__class__.__module__, orig.__class__.__name__, text, ),
self.orig = orig statement,
params,
orig
)
self.connection_invalidated = connection_invalidated self.connection_invalidated = connection_invalidated
def __str__(self):
if isinstance(self.params, (list, tuple)) and len(self.params) > 10 and isinstance(self.params[0], (list, dict, tuple)):
return ' '.join((SQLAlchemyError.__str__(self),
repr(self.statement),
repr(self.params[:2]),
'... and a total of %i bound parameter sets' % len(self.params)))
return ' '.join((SQLAlchemyError.__str__(self),
repr(self.statement), repr(self.params)))
# As of 0.4, SQLError is now DBAPIError.
# SQLError alias will be removed in 0.6.
SQLError = DBAPIError
class InterfaceError(DBAPIError): class InterfaceError(DBAPIError):
"""Wraps a DB-API InterfaceError.""" """Wraps a DB-API InterfaceError."""

View File

@ -1 +1,11 @@
# ext/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .. import util as _sa_util
_sa_util.dependencies.resolve_all("sqlalchemy.ext")

View File

@ -1,3 +1,10 @@
# ext/associationproxy.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Contain the ``AssociationProxy`` class. """Contain the ``AssociationProxy`` class.
The ``AssociationProxy`` is a Python property object which provides The ``AssociationProxy`` is a Python property object which provides
@ -9,43 +16,37 @@ See the example ``examples/association/proxied_association.py``.
import itertools import itertools
import operator import operator
import weakref import weakref
from sqlalchemy import exceptions from .. import exc, orm, util
from sqlalchemy import orm from ..orm import collections, interfaces
from sqlalchemy import util from ..sql import not_, or_
from sqlalchemy.orm import collections
from sqlalchemy.sql import not_
def association_proxy(target_collection, attr, **kw): def association_proxy(target_collection, attr, **kw):
"""Return a Python property implementing a view of *attr* over a collection. r"""Return a Python property implementing a view of a target
attribute which references an attribute on members of the
target.
Implements a read/write view over an instance's *target_collection*, The returned value is an instance of :class:`.AssociationProxy`.
extracting *attr* from each member of the collection. The property acts
somewhat like this list comprehension::
[getattr(member, *attr*) Implements a Python property representing a relationship as a collection
for member in getattr(instance, *target_collection*)] of simpler values, or a scalar value. The proxied property will mimic
the collection type of the target (list, dict or set), or, in the case of
a one to one relationship, a simple scalar value.
Unlike the list comprehension, the collection returned by the property is :param target_collection: Name of the attribute we'll proxy to.
always in sync with *target_collection*, and mutations made to either This attribute is typically mapped by
collection will be reflected in both. :func:`~sqlalchemy.orm.relationship` to link to a target collection, but
can also be a many-to-one or non-scalar relationship.
Implements a Python property representing a relationship as a collection of :param attr: Attribute on the associated instance or instances we'll
simpler values. The proxied property will mimic the collection type of proxy for.
the target (list, dict or set), or, in the case of a one to one relationship,
a simple scalar value.
:param target_collection: Name of the relationship attribute we'll proxy to,
usually created with :func:`~sqlalchemy.orm.relationship`.
:param attr: Attribute on the associated instances we'll proxy for.
For example, given a target collection of [obj1, obj2], a list created For example, given a target collection of [obj1, obj2], a list created
by this proxy property would look like [getattr(obj1, *attr*), by this proxy property would look like [getattr(obj1, *attr*),
getattr(obj2, *attr*)] getattr(obj2, *attr*)]
If the relationship is one-to-one or otherwise uselist=False, then simply: If the relationship is one-to-one or otherwise uselist=False, then
getattr(obj, *attr*) simply: getattr(obj, *attr*)
:param creator: optional. :param creator: optional.
@ -69,59 +70,78 @@ def association_proxy(target_collection, attr, **kw):
situation. situation.
:param \*\*kw: Passes along any other keyword arguments to :param \*\*kw: Passes along any other keyword arguments to
:class:`AssociationProxy`. :class:`.AssociationProxy`.
""" """
return AssociationProxy(target_collection, attr, **kw) return AssociationProxy(target_collection, attr, **kw)
class AssociationProxy(object): ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY')
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.AssociationProxy`.
Is assigned to the :attr:`.InspectionAttr.extension_type`
attibute.
"""
class AssociationProxy(interfaces.InspectionAttrInfo):
"""A descriptor that presents a read/write view of an object attribute.""" """A descriptor that presents a read/write view of an object attribute."""
is_attribute = False
extension_type = ASSOCIATION_PROXY
def __init__(self, target_collection, attr, creator=None, def __init__(self, target_collection, attr, creator=None,
getset_factory=None, proxy_factory=None, proxy_bulk_set=None): getset_factory=None, proxy_factory=None,
"""Arguments are: proxy_bulk_set=None, info=None):
"""Construct a new :class:`.AssociationProxy`.
target_collection The :func:`.association_proxy` function is provided as the usual
Name of the collection we'll proxy to, usually created with entrypoint here, though :class:`.AssociationProxy` can be instantiated
'relationship()' in a mapper setup. and/or subclassed directly.
attr :param target_collection: Name of the collection we'll proxy to,
Attribute on the collected instances we'll proxy for. For example, usually created with :func:`.relationship`.
given a target collection of [obj1, obj2], a list created by this
proxy property would look like [getattr(obj1, attr), getattr(obj2,
attr)]
creator :param attr: Attribute on the collected instances we'll proxy
Optional. When new items are added to this proxied collection, new for. For example, given a target collection of [obj1, obj2], a
instances of the class collected by the target collection will be list created by this proxy property would look like
created. For list and set collections, the target class constructor [getattr(obj1, attr), getattr(obj2, attr)]
will be called with the 'value' for the new instance. For dict
types, two arguments are passed: key and value. :param creator: Optional. When new items are added to this proxied
collection, new instances of the class collected by the target
collection will be created. For list and set collections, the
target class constructor will be called with the 'value' for the
new instance. For dict types, two arguments are passed:
key and value.
If you want to construct instances differently, supply a 'creator' If you want to construct instances differently, supply a 'creator'
function that takes arguments as above and returns instances. function that takes arguments as above and returns instances.
getset_factory :param getset_factory: Optional. Proxied attribute access is
Optional. Proxied attribute access is automatically handled by automatically handled by routines that get and set values based on
routines that get and set values based on the `attr` argument for the `attr` argument for this proxy.
this proxy.
If you would like to customize this behavior, you may supply a If you would like to customize this behavior, you may supply a
`getset_factory` callable that produces a tuple of `getter` and `getset_factory` callable that produces a tuple of `getter` and
`setter` functions. The factory is called with two arguments, the `setter` functions. The factory is called with two arguments, the
abstract type of the underlying collection and this proxy instance. abstract type of the underlying collection and this proxy instance.
proxy_factory :param proxy_factory: Optional. The type of collection to emulate is
Optional. The type of collection to emulate is determined by determined by sniffing the target collection. If your collection
sniffing the target collection. If your collection type can't be type can't be determined by duck typing or you'd like to use a
determined by duck typing or you'd like to use a different different collection implementation, you may supply a factory
collection implementation, you may supply a factory function to function to produce those collections. Only applicable to
produce those collections. Only applicable to non-scalar relationships. non-scalar relationships.
proxy_bulk_set :param proxy_bulk_set: Optional, use with proxy_factory. See
Optional, use with proxy_factory. See the _set() method for the _set() method for details.
details.
:param info: optional, will be assigned to
:attr:`.AssociationProxy.info` if present.
.. versionadded:: 1.0.9
""" """
self.target_collection = target_collection self.target_collection = target_collection
@ -131,36 +151,107 @@ class AssociationProxy(object):
self.proxy_factory = proxy_factory self.proxy_factory = proxy_factory
self.proxy_bulk_set = proxy_bulk_set self.proxy_bulk_set = proxy_bulk_set
self.scalar = None
self.owning_class = None self.owning_class = None
self.key = '_%s_%s_%s' % ( self.key = '_%s_%s_%s' % (
type(self).__name__, target_collection, id(self)) type(self).__name__, target_collection, id(self))
self.collection_class = None self.collection_class = None
if info:
self.info = info
@property
def remote_attr(self):
"""The 'remote' :class:`.MapperProperty` referenced by this
:class:`.AssociationProxy`.
.. versionadded:: 0.7.3
See also:
:attr:`.AssociationProxy.attr`
:attr:`.AssociationProxy.local_attr`
"""
return getattr(self.target_class, self.value_attr)
@property
def local_attr(self):
"""The 'local' :class:`.MapperProperty` referenced by this
:class:`.AssociationProxy`.
.. versionadded:: 0.7.3
See also:
:attr:`.AssociationProxy.attr`
:attr:`.AssociationProxy.remote_attr`
"""
return getattr(self.owning_class, self.target_collection)
@property
def attr(self):
"""Return a tuple of ``(local_attr, remote_attr)``.
This attribute is convenient when specifying a join
using :meth:`.Query.join` across two relationships::
sess.query(Parent).join(*Parent.proxied.attr)
.. versionadded:: 0.7.3
See also:
:attr:`.AssociationProxy.local_attr`
:attr:`.AssociationProxy.remote_attr`
"""
return (self.local_attr, self.remote_attr)
def _get_property(self): def _get_property(self):
return (orm.class_mapper(self.owning_class). return (orm.class_mapper(self.owning_class).
get_property(self.target_collection)) get_property(self.target_collection))
@property @util.memoized_property
def target_class(self): def target_class(self):
"""The class the proxy is attached to.""" """The intermediary class handled by this :class:`.AssociationProxy`.
Intercepted append/set/assignment events will result
in the generation of new instances of this class.
"""
return self._get_property().mapper.class_ return self._get_property().mapper.class_
def _target_is_scalar(self): @util.memoized_property
return not self._get_property().uselist def scalar(self):
"""Return ``True`` if this :class:`.AssociationProxy` proxies a scalar
relationship on the local side."""
scalar = not self._get_property().uselist
if scalar:
self._initialize_scalar_accessors()
return scalar
@util.memoized_property
def _value_is_scalar(self):
return not self._get_property().\
mapper.get_property(self.value_attr).uselist
@util.memoized_property
def _target_is_object(self):
return getattr(self.target_class, self.value_attr).impl.uses_objects
def __get__(self, obj, class_): def __get__(self, obj, class_):
if self.owning_class is None: if self.owning_class is None:
self.owning_class = class_ and class_ or type(obj) self.owning_class = class_ and class_ or type(obj)
if obj is None: if obj is None:
return self return self
elif self.scalar is None:
self.scalar = self._target_is_scalar()
if self.scalar:
self._initialize_scalar_accessors()
if self.scalar: if self.scalar:
return self._scalar_get(getattr(obj, self.target_collection)) target = getattr(obj, self.target_collection)
return self._scalar_get(target)
else: else:
try: try:
# If the owning instance is reborn (orm session resurrect, # If the owning instance is reborn (orm session resurrect,
@ -173,14 +264,10 @@ class AssociationProxy(object):
proxy = self._new(_lazy_collection(obj, self.target_collection)) proxy = self._new(_lazy_collection(obj, self.target_collection))
setattr(obj, self.key, (id(obj), proxy)) setattr(obj, self.key, (id(obj), proxy))
return proxy return proxy
def __set__(self, obj, values): def __set__(self, obj, values):
if self.owning_class is None: if self.owning_class is None:
self.owning_class = type(obj) self.owning_class = type(obj)
if self.scalar is None:
self.scalar = self._target_is_scalar()
if self.scalar:
self._initialize_scalar_accessors()
if self.scalar: if self.scalar:
creator = self.creator and self.creator or self.target_class creator = self.creator and self.creator or self.target_class
@ -209,7 +296,8 @@ class AssociationProxy(object):
def _default_getset(self, collection_class): def _default_getset(self, collection_class):
attr = self.value_attr attr = self.value_attr
getter = operator.attrgetter(attr) _getter = operator.attrgetter(attr)
getter = lambda target: _getter(target) if target is not None else None
if collection_class is dict: if collection_class is dict:
setter = lambda o, k, v: setattr(o, attr, v) setter = lambda o, k, v: setattr(o, attr, v)
else: else:
@ -221,21 +309,25 @@ class AssociationProxy(object):
self.collection_class = util.duck_type_collection(lazy_collection()) self.collection_class = util.duck_type_collection(lazy_collection())
if self.proxy_factory: if self.proxy_factory:
return self.proxy_factory(lazy_collection, creator, self.value_attr, self) return self.proxy_factory(
lazy_collection, creator, self.value_attr, self)
if self.getset_factory: if self.getset_factory:
getter, setter = self.getset_factory(self.collection_class, self) getter, setter = self.getset_factory(self.collection_class, self)
else: else:
getter, setter = self._default_getset(self.collection_class) getter, setter = self._default_getset(self.collection_class)
if self.collection_class is list: if self.collection_class is list:
return _AssociationList(lazy_collection, creator, getter, setter, self) return _AssociationList(
lazy_collection, creator, getter, setter, self)
elif self.collection_class is dict: elif self.collection_class is dict:
return _AssociationDict(lazy_collection, creator, getter, setter, self) return _AssociationDict(
lazy_collection, creator, getter, setter, self)
elif self.collection_class is set: elif self.collection_class is set:
return _AssociationSet(lazy_collection, creator, getter, setter, self) return _AssociationSet(
lazy_collection, creator, getter, setter, self)
else: else:
raise exceptions.ArgumentError( raise exc.ArgumentError(
'could not guess which interface to use for ' 'could not guess which interface to use for '
'collection_class "%s" backing "%s"; specify a ' 'collection_class "%s" backing "%s"; specify a '
'proxy_factory and proxy_bulk_set manually' % 'proxy_factory and proxy_bulk_set manually' %
@ -248,7 +340,7 @@ class AssociationProxy(object):
getter, setter = self.getset_factory(self.collection_class, self) getter, setter = self.getset_factory(self.collection_class, self)
else: else:
getter, setter = self._default_getset(self.collection_class) getter, setter = self._default_getset(self.collection_class)
proxy.creator = creator proxy.creator = creator
proxy.getter = getter proxy.getter = getter
proxy.setter = setter proxy.setter = setter
@ -263,28 +355,102 @@ class AssociationProxy(object):
elif self.collection_class is set: elif self.collection_class is set:
proxy.update(values) proxy.update(values)
else: else:
raise exceptions.ArgumentError( raise exc.ArgumentError(
'no proxy_bulk_set supplied for custom ' 'no proxy_bulk_set supplied for custom '
'collection_class implementation') 'collection_class implementation')
@property @property
def _comparator(self): def _comparator(self):
return self._get_property().comparator return self._get_property().comparator
def any(self, criterion=None, **kwargs): def any(self, criterion=None, **kwargs):
return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) """Produce a proxied 'any' expression using EXISTS.
This expression will be a composed product
using the :meth:`.RelationshipProperty.Comparator.any`
and/or :meth:`.RelationshipProperty.Comparator.has`
operators of the underlying proxied attributes.
"""
if self._target_is_object:
if self._value_is_scalar:
value_expr = getattr(
self.target_class, self.value_attr).has(
criterion, **kwargs)
else:
value_expr = getattr(
self.target_class, self.value_attr).any(
criterion, **kwargs)
else:
value_expr = criterion
# check _value_is_scalar here, otherwise
# we're scalar->scalar - call .any() so that
# the "can't call any() on a scalar" msg is raised.
if self.scalar and not self._value_is_scalar:
return self._comparator.has(
value_expr
)
else:
return self._comparator.any(
value_expr
)
def has(self, criterion=None, **kwargs): def has(self, criterion=None, **kwargs):
return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) """Produce a proxied 'has' expression using EXISTS.
This expression will be a composed product
using the :meth:`.RelationshipProperty.Comparator.any`
and/or :meth:`.RelationshipProperty.Comparator.has`
operators of the underlying proxied attributes.
"""
if self._target_is_object:
return self._comparator.has(
getattr(self.target_class, self.value_attr).
has(criterion, **kwargs)
)
else:
if criterion is not None or kwargs:
raise exc.ArgumentError(
"Non-empty has() not allowed for "
"column-targeted association proxy; use ==")
return self._comparator.has()
def contains(self, obj): def contains(self, obj):
return self._comparator.any(**{self.value_attr: obj}) """Produce a proxied 'contains' expression using EXISTS.
This expression will be a composed product
using the :meth:`.RelationshipProperty.Comparator.any`
, :meth:`.RelationshipProperty.Comparator.has`,
and/or :meth:`.RelationshipProperty.Comparator.contains`
operators of the underlying proxied attributes.
"""
if self.scalar and not self._value_is_scalar:
return self._comparator.has(
getattr(self.target_class, self.value_attr).contains(obj)
)
else:
return self._comparator.any(**{self.value_attr: obj})
def __eq__(self, obj): def __eq__(self, obj):
return self._comparator.has(**{self.value_attr: obj}) # note the has() here will fail for collections; eq_()
# is only allowed with a scalar.
if obj is None:
return or_(
self._comparator.has(**{self.value_attr: obj}),
self._comparator == None
)
else:
return self._comparator.has(**{self.value_attr: obj})
def __ne__(self, obj): def __ne__(self, obj):
return not_(self.__eq__(obj)) # note the has() here will fail for collections; eq_()
# is only allowed with a scalar.
return self._comparator.has(
getattr(self.target_class, self.value_attr) != obj)
class _lazy_collection(object): class _lazy_collection(object):
@ -295,22 +461,23 @@ class _lazy_collection(object):
def __call__(self): def __call__(self):
obj = self.ref() obj = self.ref()
if obj is None: if obj is None:
raise exceptions.InvalidRequestError( raise exc.InvalidRequestError(
"stale association proxy, parent object has gone out of " "stale association proxy, parent object has gone out of "
"scope") "scope")
return getattr(obj, self.target) return getattr(obj, self.target)
def __getstate__(self): def __getstate__(self):
return {'obj':self.ref(), 'target':self.target} return {'obj': self.ref(), 'target': self.target}
def __setstate__(self, state): def __setstate__(self, state):
self.ref = weakref.ref(state['obj']) self.ref = weakref.ref(state['obj'])
self.target = state['target'] self.target = state['target']
class _AssociationCollection(object): class _AssociationCollection(object):
def __init__(self, lazy_collection, creator, getter, setter, parent): def __init__(self, lazy_collection, creator, getter, setter, parent):
"""Constructs an _AssociationCollection. """Constructs an _AssociationCollection.
This will always be a subclass of either _AssociationList, This will always be a subclass of either _AssociationList,
_AssociationSet, or _AssociationDict. _AssociationSet, or _AssociationDict.
@ -344,17 +511,20 @@ class _AssociationCollection(object):
def __len__(self): def __len__(self):
return len(self.col) return len(self.col)
def __nonzero__(self): def __bool__(self):
return bool(self.col) return bool(self.col)
__nonzero__ = __bool__
def __getstate__(self): def __getstate__(self):
return {'parent':self.parent, 'lazy_collection':self.lazy_collection} return {'parent': self.parent, 'lazy_collection': self.lazy_collection}
def __setstate__(self, state): def __setstate__(self, state):
self.parent = state['parent'] self.parent = state['parent']
self.lazy_collection = state['lazy_collection'] self.lazy_collection = state['lazy_collection']
self.parent._inflate(self) self.parent._inflate(self)
class _AssociationList(_AssociationCollection): class _AssociationList(_AssociationCollection):
"""Generic, converting, list-to-list proxy.""" """Generic, converting, list-to-list proxy."""
@ -368,7 +538,10 @@ class _AssociationList(_AssociationCollection):
return self.setter(object, value) return self.setter(object, value)
def __getitem__(self, index): def __getitem__(self, index):
return self._get(self.col[index]) if not isinstance(index, slice):
return self._get(self.col[index])
else:
return [self._get(member) for member in self.col[index]]
def __setitem__(self, index, value): def __setitem__(self, index, value):
if not isinstance(index, slice): if not isinstance(index, slice):
@ -382,11 +555,12 @@ class _AssociationList(_AssociationCollection):
stop = index.stop stop = index.stop
step = index.step or 1 step = index.step or 1
rng = range(index.start or 0, stop, step) start = index.start or 0
rng = list(range(index.start or 0, stop, step))
if step == 1: if step == 1:
for i in rng: for i in rng:
del self[index.start] del self[start]
i = index.start i = start
for item in value: for item in value:
self.insert(i, item) self.insert(i, item)
i += 1 i += 1
@ -429,7 +603,7 @@ class _AssociationList(_AssociationCollection):
for member in self.col: for member in self.col:
yield self._get(member) yield self._get(member)
raise StopIteration return
def append(self, value): def append(self, value):
item = self._create(value) item = self._create(value)
@ -437,7 +611,7 @@ class _AssociationList(_AssociationCollection):
def count(self, value): def count(self, value):
return sum([1 for _ in return sum([1 for _ in
itertools.ifilter(lambda v: v == value, iter(self))]) util.itertools_filter(lambda v: v == value, iter(self))])
def extend(self, values): def extend(self, values):
for v in values: for v in values:
@ -536,14 +710,16 @@ class _AssociationList(_AssociationCollection):
def __hash__(self): def __hash__(self):
raise TypeError("%s objects are unhashable" % type(self).__name__) raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items(): for func_name, func in list(locals().items()):
if (util.callable(func) and func.func_name == func_name and if (util.callable(func) and func.__name__ == func_name and
not func.__doc__ and hasattr(list, func_name)): not func.__doc__ and hasattr(list, func_name)):
func.__doc__ = getattr(list, func_name).__doc__ func.__doc__ = getattr(list, func_name).__doc__
del func_name, func del func_name, func
_NotProvided = util.symbol('_NotProvided') _NotProvided = util.symbol('_NotProvided')
class _AssociationDict(_AssociationCollection): class _AssociationDict(_AssociationCollection):
"""Generic, converting, dict-to-dict proxy.""" """Generic, converting, dict-to-dict proxy."""
@ -577,7 +753,7 @@ class _AssociationDict(_AssociationCollection):
return key in self.col return key in self.col
def __iter__(self): def __iter__(self):
return self.col.iterkeys() return iter(self.col.keys())
def clear(self): def clear(self):
self.col.clear() self.col.clear()
@ -622,24 +798,27 @@ class _AssociationDict(_AssociationCollection):
def keys(self): def keys(self):
return self.col.keys() return self.col.keys()
def iterkeys(self): if util.py2k:
return self.col.iterkeys() def iteritems(self):
return ((key, self._get(self.col[key])) for key in self.col)
def values(self): def itervalues(self):
return [ self._get(member) for member in self.col.values() ] return (self._get(self.col[key]) for key in self.col)
def itervalues(self): def iterkeys(self):
for key in self.col: return self.col.iterkeys()
yield self._get(self.col[key])
raise StopIteration
def items(self): def values(self):
return [(k, self._get(self.col[k])) for k in self] return [self._get(member) for member in self.col.values()]
def iteritems(self): def items(self):
for key in self.col: return [(k, self._get(self.col[k])) for k in self]
yield (key, self._get(self.col[key])) else:
raise StopIteration def items(self):
return ((key, self._get(self.col[key])) for key in self.col)
def values(self):
return (self._get(self.col[key]) for key in self.col)
def pop(self, key, default=_NotProvided): def pop(self, key, default=_NotProvided):
if default is _NotProvided: if default is _NotProvided:
@ -658,11 +837,20 @@ class _AssociationDict(_AssociationCollection):
len(a)) len(a))
elif len(a) == 1: elif len(a) == 1:
seq_or_map = a[0] seq_or_map = a[0]
for item in seq_or_map: # discern dict from sequence - took the advice from
if isinstance(item, tuple): # http://www.voidspace.org.uk/python/articles/duck_typing.shtml
self[item[0]] = item[1] # still not perfect :(
else: if hasattr(seq_or_map, 'keys'):
for item in seq_or_map:
self[item] = seq_or_map[item] self[item] = seq_or_map[item]
else:
try:
for k, v in seq_or_map:
self[k] = v
except ValueError:
raise ValueError(
"dictionary update sequence "
"requires 2-element tuples")
for key, value in kw: for key, value in kw:
self[key] = value self[key] = value
@ -673,9 +861,9 @@ class _AssociationDict(_AssociationCollection):
def __hash__(self): def __hash__(self):
raise TypeError("%s objects are unhashable" % type(self).__name__) raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items(): for func_name, func in list(locals().items()):
if (util.callable(func) and func.func_name == func_name and if (util.callable(func) and func.__name__ == func_name and
not func.__doc__ and hasattr(dict, func_name)): not func.__doc__ and hasattr(dict, func_name)):
func.__doc__ = getattr(dict, func_name).__doc__ func.__doc__ = getattr(dict, func_name).__doc__
del func_name, func del func_name, func
@ -695,12 +883,14 @@ class _AssociationSet(_AssociationCollection):
def __len__(self): def __len__(self):
return len(self.col) return len(self.col)
def __nonzero__(self): def __bool__(self):
if self.col: if self.col:
return True return True
else: else:
return False return False
__nonzero__ = __bool__
def __contains__(self, value): def __contains__(self, value):
for member in self.col: for member in self.col:
# testlib.pragma exempt:__eq__ # testlib.pragma exempt:__eq__
@ -717,7 +907,7 @@ class _AssociationSet(_AssociationCollection):
""" """
for member in self.col: for member in self.col:
yield self._get(member) yield self._get(member)
raise StopIteration return
def add(self, value): def add(self, value):
if value not in self: if value not in self:
@ -871,8 +1061,8 @@ class _AssociationSet(_AssociationCollection):
def __hash__(self): def __hash__(self):
raise TypeError("%s objects are unhashable" % type(self).__name__) raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in locals().items(): for func_name, func in list(locals().items()):
if (util.callable(func) and func.func_name == func_name and if (util.callable(func) and func.__name__ == func_name and
not func.__doc__ and hasattr(set, func_name)): not func.__doc__ and hasattr(set, func_name)):
func.__doc__ = getattr(set, func_name).__doc__ func.__doc__ = getattr(set, func_name).__doc__
del func_name, func del func_name, func

View File

@ -1,31 +1,39 @@
"""Provides an API for creation of custom ClauseElements and compilers. # ext/compiler.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
r"""Provides an API for creation of custom ClauseElements and compilers.
Synopsis Synopsis
======== ========
Usage involves the creation of one or more :class:`~sqlalchemy.sql.expression.ClauseElement` Usage involves the creation of one or more
subclasses and one or more callables defining its compilation:: :class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or
more callables defining its compilation::
from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnClause from sqlalchemy.sql.expression import ColumnClause
class MyColumn(ColumnClause): class MyColumn(ColumnClause):
pass pass
@compiles(MyColumn) @compiles(MyColumn)
def compile_mycolumn(element, compiler, **kw): def compile_mycolumn(element, compiler, **kw):
return "[%s]" % element.name return "[%s]" % element.name
Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
the base expression element for named column objects. The ``compiles`` the base expression element for named column objects. The ``compiles``
decorator registers itself with the ``MyColumn`` class so that it is invoked decorator registers itself with the ``MyColumn`` class so that it is invoked
when the object is compiled to a string:: when the object is compiled to a string::
from sqlalchemy import select from sqlalchemy import select
s = select([MyColumn('x'), MyColumn('y')]) s = select([MyColumn('x'), MyColumn('y')])
print str(s) print str(s)
Produces:: Produces::
SELECT [x], [y] SELECT [x], [y]
@ -50,22 +58,25 @@ invoked for the dialect in use::
@compiles(AlterColumn, 'postgresql') @compiles(AlterColumn, 'postgresql')
def visit_alter_column(element, compiler, **kw): def visit_alter_column(element, compiler, **kw):
return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, element.column.name) return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name,
element.column.name)
The second ``visit_alter_table`` will be invoked when any ``postgresql`` dialect is used. The second ``visit_alter_table`` will be invoked when any ``postgresql``
dialect is used.
Compiling sub-elements of a custom expression construct Compiling sub-elements of a custom expression construct
======================================================= =======================================================
The ``compiler`` argument is the :class:`~sqlalchemy.engine.base.Compiled` The ``compiler`` argument is the
object in use. This object can be inspected for any information about the :class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object
in-progress compilation, including ``compiler.dialect``, can be inspected for any information about the in-progress compilation,
``compiler.statement`` etc. The :class:`~sqlalchemy.sql.compiler.SQLCompiler` including ``compiler.dialect``, ``compiler.statement`` etc. The
and :class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()`` :class:`~sqlalchemy.sql.compiler.SQLCompiler` and
:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
method which can be used for compilation of embedded attributes:: method which can be used for compilation of embedded attributes::
from sqlalchemy.sql.expression import Executable, ClauseElement from sqlalchemy.sql.expression import Executable, ClauseElement
class InsertFromSelect(Executable, ClauseElement): class InsertFromSelect(Executable, ClauseElement):
def __init__(self, table, select): def __init__(self, table, select):
self.table = table self.table = table
@ -80,36 +91,110 @@ method which can be used for compilation of embedded attributes::
insert = InsertFromSelect(t1, select([t1]).where(t1.c.x>5)) insert = InsertFromSelect(t1, select([t1]).where(t1.c.x>5))
print insert print insert
Produces:: Produces::
"INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)" "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z
FROM mytable WHERE mytable.x > :x_1)"
.. note::
The above ``InsertFromSelect`` construct is only an example, this actual
functionality is already available using the
:meth:`.Insert.from_select` method.
.. note::
The above ``InsertFromSelect`` construct probably wants to have "autocommit"
enabled. See :ref:`enabling_compiled_autocommit` for this step.
Cross Compiling between SQL and DDL compilers Cross Compiling between SQL and DDL compilers
--------------------------------------------- ---------------------------------------------
SQL and DDL constructs are each compiled using different base compilers - ``SQLCompiler`` SQL and DDL constructs are each compiled using different base compilers -
and ``DDLCompiler``. A common need is to access the compilation rules of SQL expressions ``SQLCompiler`` and ``DDLCompiler``. A common need is to access the
from within a DDL expression. The ``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as below where we generate a CHECK compilation rules of SQL expressions from within a DDL expression. The
constraint that embeds a SQL expression:: ``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as
below where we generate a CHECK constraint that embeds a SQL expression::
@compiles(MyConstraint) @compiles(MyConstraint)
def compile_my_constraint(constraint, ddlcompiler, **kw): def compile_my_constraint(constraint, ddlcompiler, **kw):
return "CONSTRAINT %s CHECK (%s)" % ( return "CONSTRAINT %s CHECK (%s)" % (
constraint.name, constraint.name,
ddlcompiler.sql_compiler.process(constraint.expression) ddlcompiler.sql_compiler.process(
constraint.expression, literal_binds=True)
) )
Above, we add an additional flag to the process step as called by
:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This
indicates that any SQL expression which refers to a :class:`.BindParameter`
object or other "literal" object such as those which refer to strings or
integers should be rendered **in-place**, rather than being referred to as
a bound parameter; when emitting DDL, bound parameters are typically not
supported.
.. _enabling_compiled_autocommit:
Enabling Autocommit on a Construct
==================================
Recall from the section :ref:`autocommit` that the :class:`.Engine`, when
asked to execute a construct in the absence of a user-defined transaction,
detects if the given construct represents DML or DDL, that is, a data
modification or data definition statement, which requires (or may require,
in the case of DDL) that the transaction generated by the DBAPI be committed
(recall that DBAPI always has a transaction going on regardless of what
SQLAlchemy does). Checking for this is actually accomplished by checking for
the "autocommit" execution option on the construct. When building a
construct like an INSERT derivation, a new DDL type, or perhaps a stored
procedure that alters data, the "autocommit" option needs to be set in order
for the statement to function with "connectionless" execution
(as described in :ref:`dbengine_implicit`).
Currently a quick way to do this is to subclass :class:`.Executable`, then
add the "autocommit" flag to the ``_execution_options`` dictionary (note this
is a "frozen" dictionary which supplies a generative ``union()`` method)::
from sqlalchemy.sql.expression import Executable, ClauseElement
class MyInsertThing(Executable, ClauseElement):
_execution_options = \
Executable._execution_options.union({'autocommit': True})
More succinctly, if the construct is truly similar to an INSERT, UPDATE, or
DELETE, :class:`.UpdateBase` can be used, which already is a subclass
of :class:`.Executable`, :class:`.ClauseElement` and includes the
``autocommit`` flag::
from sqlalchemy.sql.expression import UpdateBase
class MyInsertThing(UpdateBase):
def __init__(self, ...):
...
DDL elements that subclass :class:`.DDLElement` already have the
"autocommit" flag turned on.
Changing the default compilation of existing constructs Changing the default compilation of existing constructs
======================================================= =======================================================
The compiler extension applies just as well to the existing constructs. When overriding The compiler extension applies just as well to the existing constructs. When
the compilation of a built in SQL construct, the @compiles decorator is invoked upon overriding the compilation of a built in SQL construct, the @compiles
the appropriate class (be sure to use the class, i.e. ``Insert`` or ``Select``, instead of the creation function such as ``insert()`` or ``select()``). decorator is invoked upon the appropriate class (be sure to use the class,
i.e. ``Insert`` or ``Select``, instead of the creation function such
as ``insert()`` or ``select()``).
Within the new compilation function, to get at the "original" compilation routine, Within the new compilation function, to get at the "original" compilation
use the appropriate visit_XXX method - this because compiler.process() will call upon the routine, use the appropriate visit_XXX method - this
overriding routine and cause an endless loop. Such as, to add "prefix" to all insert statements:: because compiler.process() will call upon the overriding routine and cause
an endless loop. Such as, to add "prefix" to all insert statements::
from sqlalchemy.sql.expression import Insert from sqlalchemy.sql.expression import Insert
@ -117,38 +202,77 @@ overriding routine and cause an endless loop. Such as, to add "prefix" to all
def prefix_inserts(insert, compiler, **kw): def prefix_inserts(insert, compiler, **kw):
return compiler.visit_insert(insert.prefix_with("some prefix"), **kw) return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
The above compiler will prefix all INSERT statements with "some prefix" when compiled. The above compiler will prefix all INSERT statements with "some prefix" when
compiled.
.. _type_compilation_extension:
Changing Compilation of Types
=============================
``compiler`` works for types, too, such as below where we implement the
MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
@compiles(String, 'mssql')
@compiles(VARCHAR, 'mssql')
def compile_varchar(element, compiler, **kw):
if element.length == 'max':
return "VARCHAR('max')"
else:
return compiler.visit_VARCHAR(element, **kw)
foo = Table('foo', metadata,
Column('data', VARCHAR('max'))
)
Subclassing Guidelines Subclassing Guidelines
====================== ======================
A big part of using the compiler extension is subclassing SQLAlchemy expression constructs. To make this easier, the expression and schema packages feature a set of "bases" intended for common tasks. A synopsis is as follows: A big part of using the compiler extension is subclassing SQLAlchemy
expression constructs. To make this easier, the expression and
schema packages feature a set of "bases" intended for common tasks.
A synopsis is as follows:
* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root * :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
expression class. Any SQL expression can be derived from this base, and is expression class. Any SQL expression can be derived from this base, and is
probably the best choice for longer constructs such as specialized INSERT probably the best choice for longer constructs such as specialized INSERT
statements. statements.
* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all * :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
"column-like" elements. Anything that you'd place in the "columns" clause of "column-like" elements. Anything that you'd place in the "columns" clause of
a SELECT statement (as well as order by and group by) can derive from this - a SELECT statement (as well as order by and group by) can derive from this -
the object will automatically have Python "comparison" behavior. the object will automatically have Python "comparison" behavior.
:class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
``type`` member which is expression's return type. This can be established ``type`` member which is expression's return type. This can be established
at the instance level in the constructor, or at the class level if its at the instance level in the constructor, or at the class level if its
generally constant:: generally constant::
class timestamp(ColumnElement): class timestamp(ColumnElement):
type = TIMESTAMP() type = TIMESTAMP()
* :class:`~sqlalchemy.sql.expression.FunctionElement` - This is a hybrid of a * :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a
``ColumnElement`` and a "from clause" like object, and represents a SQL ``ColumnElement`` and a "from clause" like object, and represents a SQL
function or stored procedure type of call. Since most databases support function or stored procedure type of call. Since most databases support
statements along the line of "SELECT FROM <some function>" statements along the line of "SELECT FROM <some function>"
``FunctionElement`` adds in the ability to be used in the FROM clause of a ``FunctionElement`` adds in the ability to be used in the FROM clause of a
``select()`` construct. ``select()`` construct::
from sqlalchemy.sql.expression import FunctionElement
class coalesce(FunctionElement):
name = 'coalesce'
@compiles(coalesce)
def compile(element, compiler, **kw):
return "coalesce(%s)" % compiler.process(element.clauses)
@compiles(coalesce, 'oracle')
def compile(element, compiler, **kw):
if len(element.clauses) > 2:
raise TypeError("coalesce only supports two arguments on Oracle")
return "nvl(%s)" % compiler.process(element.clauses)
* :class:`~sqlalchemy.schema.DDLElement` - The root of all DDL expressions, * :class:`~sqlalchemy.schema.DDLElement` - The root of all DDL expressions,
like CREATE TABLE, ALTER TABLE, etc. Compilation of ``DDLElement`` like CREATE TABLE, ALTER TABLE, etc. Compilation of ``DDLElement``
subclasses is issued by a ``DDLCompiler`` instead of a ``SQLCompiler``. subclasses is issued by a ``DDLCompiler`` instead of a ``SQLCompiler``.
@ -156,39 +280,195 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression
``execute_at()`` method, allowing the construct to be invoked during CREATE ``execute_at()`` method, allowing the construct to be invoked during CREATE
TABLE and DROP TABLE sequences. TABLE and DROP TABLE sequences.
* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which should be * :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which
used with any expression class that represents a "standalone" SQL statement that should be used with any expression class that represents a "standalone"
can be passed directly to an ``execute()`` method. It is already implicit SQL statement that can be passed directly to an ``execute()`` method. It
within ``DDLElement`` and ``FunctionElement``. is already implicit within ``DDLElement`` and ``FunctionElement``.
Further Examples
================
"UTC timestamp" function
-------------------------
A function that works like "CURRENT_TIMESTAMP" except applies the
appropriate conversions so that the time is in UTC time. Timestamps are best
stored in relational databases as UTC, without time zones. UTC so that your
database doesn't think time has gone backwards in the hour when daylight
savings ends, without timezones because timezones are like character
encodings - they're best applied only at the endpoints of an application
(i.e. convert to UTC upon user input, re-apply desired timezone upon display).
For PostgreSQL and Microsoft SQL Server::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import DateTime
class utcnow(expression.FunctionElement):
type = DateTime()
@compiles(utcnow, 'postgresql')
def pg_utcnow(element, compiler, **kw):
return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
@compiles(utcnow, 'mssql')
def ms_utcnow(element, compiler, **kw):
return "GETUTCDATE()"
Example usage::
from sqlalchemy import (
Table, Column, Integer, String, DateTime, MetaData
)
metadata = MetaData()
event = Table("event", metadata,
Column("id", Integer, primary_key=True),
Column("description", String(50), nullable=False),
Column("timestamp", DateTime, server_default=utcnow())
)
"GREATEST" function
-------------------
The "GREATEST" function is given any number of arguments and returns the one
that is of the highest value - its equivalent to Python's ``max``
function. A SQL standard version versus a CASE based version which only
accommodates two arguments::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import Numeric
class greatest(expression.FunctionElement):
type = Numeric()
name = 'greatest'
@compiles(greatest)
def default_greatest(element, compiler, **kw):
return compiler.visit_function(element)
@compiles(greatest, 'sqlite')
@compiles(greatest, 'mssql')
@compiles(greatest, 'oracle')
def case_greatest(element, compiler, **kw):
arg1, arg2 = list(element.clauses)
return "CASE WHEN %s > %s THEN %s ELSE %s END" % (
compiler.process(arg1),
compiler.process(arg2),
compiler.process(arg1),
compiler.process(arg2),
)
Example usage::
Session.query(Account).\
filter(
greatest(
Account.checking_balance,
Account.savings_balance) > 10000
)
"false" expression
------------------
Render a "false" constant expression, rendering as "0" on platforms that
don't have a "false" constant::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
class sql_false(expression.ColumnElement):
pass
@compiles(sql_false)
def default_false(element, compiler, **kw):
return "false"
@compiles(sql_false, 'mssql')
@compiles(sql_false, 'mysql')
@compiles(sql_false, 'oracle')
def int_false(element, compiler, **kw):
return "0"
Example usage::
from sqlalchemy import select, union_all
exp = union_all(
select([users.c.name, sql_false().label("enrolled")]),
select([customers.c.name, customers.c.enrolled])
)
""" """
from .. import exc
from ..sql import visitors
def compiles(class_, *specs): def compiles(class_, *specs):
"""Register a function as a compiler for a
given :class:`.ClauseElement` type."""
def decorate(fn): def decorate(fn):
existing = getattr(class_, '_compiler_dispatcher', None) # get an existing @compiles handler
existing = class_.__dict__.get('_compiler_dispatcher', None)
# get the original handler. All ClauseElement classes have one
# of these, but some TypeEngine classes will not.
existing_dispatch = getattr(class_, '_compiler_dispatch', None)
if not existing: if not existing:
existing = _dispatcher() existing = _dispatcher()
if existing_dispatch:
def _wrap_existing_dispatch(element, compiler, **kw):
try:
return existing_dispatch(element, compiler, **kw)
except exc.UnsupportedCompilationError:
raise exc.CompileError(
"%s construct has no default "
"compilation handler." % type(element))
existing.specs['default'] = _wrap_existing_dispatch
# TODO: why is the lambda needed ? # TODO: why is the lambda needed ?
setattr(class_, '_compiler_dispatch', lambda *arg, **kw: existing(*arg, **kw)) setattr(class_, '_compiler_dispatch',
lambda *arg, **kw: existing(*arg, **kw))
setattr(class_, '_compiler_dispatcher', existing) setattr(class_, '_compiler_dispatcher', existing)
if specs: if specs:
for s in specs: for s in specs:
existing.specs[s] = fn existing.specs[s] = fn
else: else:
existing.specs['default'] = fn existing.specs['default'] = fn
return fn return fn
return decorate return decorate
def deregister(class_):
"""Remove all custom compilers associated with a given
:class:`.ClauseElement` type."""
if hasattr(class_, '_compiler_dispatcher'):
# regenerate default _compiler_dispatch
visitors._generate_dispatch(class_)
# remove custom directive
del class_._compiler_dispatcher
class _dispatcher(object): class _dispatcher(object):
def __init__(self): def __init__(self):
self.specs = {} self.specs = {}
def __call__(self, element, compiler, **kw): def __call__(self, element, compiler, **kw):
# TODO: yes, this could also switch off of DBAPI in use. # TODO: yes, this could also switch off of DBAPI in use.
fn = self.specs.get(compiler.dialect.name, None) fn = self.specs.get(compiler.dialect.name, None)
if not fn: if not fn:
fn = self.specs['default'] try:
fn = self.specs['default']
except KeyError:
raise exc.CompileError(
"%s construct has no default "
"compilation handler." % type(element))
return fn(element, compiler, **kw) return fn(element, compiler, **kw)

View File

@ -1,5 +1,6 @@
# horizontal_shard.py # ext/horizontal_shard.py
# Copyright (C) the SQLAlchemy authors and contributors # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -9,105 +10,54 @@
Defines a rudimental 'horizontal sharding' system which allows a Session to Defines a rudimental 'horizontal sharding' system which allows a Session to
distribute queries and persistence operations across multiple databases. distribute queries and persistence operations across multiple databases.
For a usage example, see the :ref:`examples_sharding` example included in For a usage example, see the :ref:`examples_sharding` example included in
the source distrbution. the source distribution.
""" """
import sqlalchemy.exceptions as sa_exc from .. import util
from sqlalchemy import util from ..orm.session import Session
from sqlalchemy.orm.session import Session from ..orm.query import Query
from sqlalchemy.orm.query import Query
__all__ = ['ShardedSession', 'ShardedQuery'] __all__ = ['ShardedSession', 'ShardedQuery']
class ShardedSession(Session):
def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped instance, and possibly a
SQL clause, returns a shard ID. This id may be based off of the
attributes present within the object, or on some round-robin
scheme. If the scheme is based on a selection, it should set
whatever state on the instance to mark it in the future as
participating in that shard.
:param id_chooser: A callable, passed a query and a tuple of identity values, which
should return a list of shard ids where the ID might reside. The
databases will be queried in the order of this listing.
:param query_chooser: For a given Query, returns the list of shard_ids where the query
should be issued. Results from all shards returned will be combined
together into a single listing.
:param shards: A dictionary of string shard names to :class:`~sqlalchemy.engine.base.Engine`
objects.
"""
super(ShardedSession, self).__init__(**kwargs)
self.shard_chooser = shard_chooser
self.id_chooser = id_chooser
self.query_chooser = query_chooser
self.__binds = {}
self._mapper_flush_opts = {'connection_callable':self.connection}
self._query_cls = ShardedQuery
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance)
if self.transaction is not None:
return self.transaction.connection(mapper, shard_id=shard_id)
else:
return self.get_bind(mapper,
shard_id=shard_id,
instance=instance).contextual_connect(**kwargs)
def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance, clause=clause)
return self.__binds[shard_id]
def bind_shard(self, shard_id, bind):
self.__binds[shard_id] = bind
class ShardedQuery(Query): class ShardedQuery(Query):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ShardedQuery, self).__init__(*args, **kwargs) super(ShardedQuery, self).__init__(*args, **kwargs)
self.id_chooser = self.session.id_chooser self.id_chooser = self.session.id_chooser
self.query_chooser = self.session.query_chooser self.query_chooser = self.session.query_chooser
self._shard_id = None self._shard_id = None
def set_shard(self, shard_id): def set_shard(self, shard_id):
"""return a new query, limited to a single shard ID. """return a new query, limited to a single shard ID.
all subsequent operations with the returned query will all subsequent operations with the returned query will
be against the single shard regardless of other state. be against the single shard regardless of other state.
""" """
q = self._clone() q = self._clone()
q._shard_id = shard_id q._shard_id = shard_id
return q return q
def _execute_and_instances(self, context): def _execute_and_instances(self, context):
if self._shard_id is not None: def iter_for_shard(shard_id):
result = self.session.connection( context.attributes['shard_id'] = shard_id
mapper=self._mapper_zero(), result = self._connection_from_session(
shard_id=self._shard_id).execute(context.statement, self._params) mapper=self._mapper_zero(),
shard_id=shard_id).execute(
context.statement,
self._params)
return self.instances(result, context) return self.instances(result, context)
if self._shard_id is not None:
return iter_for_shard(self._shard_id)
else: else:
partial = [] partial = []
for shard_id in self.query_chooser(self): for shard_id in self.query_chooser(self):
result = self.session.connection( partial.extend(iter_for_shard(shard_id))
mapper=self._mapper_zero(),
shard_id=shard_id).execute(context.statement, self._params) # if some kind of in memory 'sorting'
partial = partial + list(self.instances(result, context))
# if some kind of in memory 'sorting'
# were done, this is where it would happen # were done, this is where it would happen
return iter(partial) return iter(partial)
@ -122,4 +72,60 @@ class ShardedQuery(Query):
return o return o
else: else:
return None return None
class ShardedSession(Session):
def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None,
query_cls=ShardedQuery, **kwargs):
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped
instance, and possibly a SQL clause, returns a shard ID. This id
may be based off of the attributes present within the object, or on
some round-robin scheme. If the scheme is based on a selection, it
should set whatever state on the instance to mark it in the future as
participating in that shard.
:param id_chooser: A callable, passed a query and a tuple of identity
values, which should return a list of shard ids where the ID might
reside. The databases will be queried in the order of this listing.
:param query_chooser: For a given Query, returns the list of shard_ids
where the query should be issued. Results from all shards returned
will be combined together into a single listing.
:param shards: A dictionary of string shard names
to :class:`~sqlalchemy.engine.Engine` objects.
"""
super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
self.shard_chooser = shard_chooser
self.id_chooser = id_chooser
self.query_chooser = query_chooser
self.__binds = {}
self.connection_callable = self.connection
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance)
if self.transaction is not None:
return self.transaction.connection(mapper, shard_id=shard_id)
else:
return self.get_bind(
mapper,
shard_id=shard_id,
instance=instance
).contextual_connect(**kwargs)
def get_bind(self, mapper, shard_id=None,
instance=None, clause=None, **kw):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance, clause=clause)
return self.__binds[shard_id]
def bind_shard(self, shard_id, bind):
self.__binds[shard_id] = bind

View File

@ -1,58 +1,78 @@
"""A custom list that manages index/position information for its children. # ext/orderinglist.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""A custom list that manages index/position information for contained
elements.
:author: Jason Kirtland :author: Jason Kirtland
``orderinglist`` is a helper for mutable ordered relationships. It will intercept ``orderinglist`` is a helper for mutable ordered relationships. It will
list operations performed on a relationship collection and automatically intercept list operations performed on a :func:`.relationship`-managed
synchronize changes in list position with an attribute on the related objects. collection and
(See :ref:`advdatamapping_entitycollections` for more information on the general pattern.) automatically synchronize changes in list position onto a target scalar
attribute.
Example: Two tables that store slides in a presentation. Each slide Example: A ``slide`` table, where each row refers to zero or more entries
has a number of bullet points, displayed in order by the 'position' in a related ``bullet`` table. The bullets within a slide are
column on the bullets table. These bullets can be inserted and re-ordered displayed in order based on the value of the ``position`` column in the
by your end users, and you need to update the 'position' column of all ``bullet`` table. As entries are reordered in memory, the value of the
affected rows when changes are made. ``position`` attribute should be updated to reflect the new sort order::
.. sourcecode:: python+sql
slides_table = Table('Slides', metadata, Base = declarative_base()
Column('id', Integer, primary_key=True),
Column('name', String))
bullets_table = Table('Bullets', metadata, class Slide(Base):
Column('id', Integer, primary_key=True), __tablename__ = 'slide'
Column('slide_id', Integer, ForeignKey('Slides.id')),
Column('position', Integer),
Column('text', String))
class Slide(object): id = Column(Integer, primary_key=True)
pass name = Column(String)
class Bullet(object):
pass
mapper(Slide, slides_table, properties={ bullets = relationship("Bullet", order_by="Bullet.position")
'bullets': relationship(Bullet, order_by=[bullets_table.c.position])
})
mapper(Bullet, bullets_table)
The standard relationship mapping will produce a list-like attribute on each Slide class Bullet(Base):
containing all related Bullets, but coping with changes in ordering is totally __tablename__ = 'bullet'
your responsibility. If you insert a Bullet into that list, there is no id = Column(Integer, primary_key=True)
magic- it won't have a position attribute unless you assign it it one, and slide_id = Column(Integer, ForeignKey('slide.id'))
you'll need to manually renumber all the subsequent Bullets in the list to position = Column(Integer)
accommodate the insert. text = Column(String)
An ``orderinglist`` can automate this and manage the 'position' attribute on all The standard relationship mapping will produce a list-like attribute on each
related bullets for you. ``Slide`` containing all related ``Bullet`` objects,
but coping with changes in ordering is not handled automatically.
When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position``
attribute will remain unset until manually assigned. When the ``Bullet``
is inserted into the middle of the list, the following ``Bullet`` objects
will also need to be renumbered.
.. sourcecode:: python+sql The :class:`.OrderingList` object automates this task, managing the
``position`` attribute on all ``Bullet`` objects in the collection. It is
mapper(Slide, slides_table, properties={ constructed using the :func:`.ordering_list` factory::
'bullets': relationship(Bullet,
collection_class=ordering_list('position'), from sqlalchemy.ext.orderinglist import ordering_list
order_by=[bullets_table.c.position])
}) Base = declarative_base()
mapper(Bullet, bullets_table)
class Slide(Base):
__tablename__ = 'slide'
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position",
collection_class=ordering_list('position'))
class Bullet(Base):
__tablename__ = 'bullet'
id = Column(Integer, primary_key=True)
slide_id = Column(Integer, ForeignKey('slide.id'))
position = Column(Integer)
text = Column(String)
With the above mapping the ``Bullet.position`` attribute is managed::
s = Slide() s = Slide()
s.bullets.append(Bullet()) s.bullets.append(Bullet())
@ -63,71 +83,98 @@ related bullets for you.
s.bullets[2].position s.bullets[2].position
>>> 2 >>> 2
Use the ``ordering_list`` function to set up the ``collection_class`` on relationships The :class:`.OrderingList` construct only works with **changes** to a
(as in the mapper example above). This implementation depends on the list collection, and not the initial load from the database, and requires that the
starting in the proper order, so be SURE to put an order_by on your relationship. list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the
:func:`.relationship` against the target ordering attribute, so that the
ordering is correct when first loaded.
.. warning:: ``ordering_list`` only provides limited functionality when a primary .. warning::
key column or unique column is the target of the sort. Since changing the order of
entries often means that two rows must trade values, this is not possible when
the value is constrained by a primary key or unique constraint, since one of the rows
would temporarily have to point to a third available value so that the other row
could take its old value. ``ordering_list`` doesn't do any of this for you,
nor does SQLAlchemy itself.
``ordering_list`` takes the name of the related object's ordering attribute as :class:`.OrderingList` only provides limited functionality when a primary
an argument. By default, the zero-based integer index of the object's key column or unique column is the target of the sort. Operations
position in the ``ordering_list`` is synchronized with the ordering attribute: that are unsupported or are problematic include:
index 0 will get position 0, index 1 position 1, etc. To start numbering at 1
or some other integer, provide ``count_from=1``.
Ordering values are not limited to incrementing integers. Almost any scheme * two entries must trade values. This is not supported directly in the
can implemented by supplying a custom ``ordering_func`` that maps a Python list case of a primary key or unique constraint because it means at least
index to any value you require. one row would need to be temporarily removed first, or changed to
a third, neutral value while the switch occurs.
* an entry must be deleted in order to make room for a new entry.
SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
single flush. In the case of a primary key, it will trade
an INSERT/DELETE of the same primary key for an UPDATE statement in order
to lessen the impact of this limitation, however this does not take place
for a UNIQUE column.
A future feature will allow the "DELETE before INSERT" behavior to be
possible, allevating this limitation, though this feature will require
explicit configuration at the mapper level for sets of columns that
are to be handled in this way.
:func:`.ordering_list` takes the name of the related object's ordering
attribute as an argument. By default, the zero-based integer index of the
object's position in the :func:`.ordering_list` is synchronized with the
ordering attribute: index 0 will get position 0, index 1 position 1, etc. To
start numbering at 1 or some other integer, provide ``count_from=1``.
""" """
from sqlalchemy.orm.collections import collection from ..orm.collections import collection, collection_adapter
from sqlalchemy import util from .. import util
__all__ = [ 'ordering_list' ] __all__ = ['ordering_list']
def ordering_list(attr, count_from=None, **kw): def ordering_list(attr, count_from=None, **kw):
"""Prepares an OrderingList factory for use in mapper definitions. """Prepares an :class:`OrderingList` factory for use in mapper definitions.
Returns an object suitable for use as an argument to a Mapper relationship's Returns an object suitable for use as an argument to a Mapper
``collection_class`` option. Arguments are: relationship's ``collection_class`` option. e.g.::
attr from sqlalchemy.ext.orderinglist import ordering_list
class Slide(Base):
__tablename__ = 'slide'
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position",
collection_class=ordering_list('position'))
:param attr:
Name of the mapped attribute to use for storage and retrieval of Name of the mapped attribute to use for storage and retrieval of
ordering information ordering information
count_from (optional) :param count_from:
Set up an integer-based ordering, starting at ``count_from``. For Set up an integer-based ordering, starting at ``count_from``. For
example, ``ordering_list('pos', count_from=1)`` would create a 1-based example, ``ordering_list('pos', count_from=1)`` would create a 1-based
list in SQL, storing the value in the 'pos' column. Ignored if list in SQL, storing the value in the 'pos' column. Ignored if
``ordering_func`` is supplied. ``ordering_func`` is supplied.
Passes along any keyword arguments to ``OrderingList`` constructor. Additional arguments are passed to the :class:`.OrderingList` constructor.
""" """
kw = _unsugar_count_from(count_from=count_from, **kw) kw = _unsugar_count_from(count_from=count_from, **kw)
return lambda: OrderingList(attr, **kw) return lambda: OrderingList(attr, **kw)
# Ordering utility functions # Ordering utility functions
def count_from_0(index, collection): def count_from_0(index, collection):
"""Numbering function: consecutive integers starting at 0.""" """Numbering function: consecutive integers starting at 0."""
return index return index
def count_from_1(index, collection): def count_from_1(index, collection):
"""Numbering function: consecutive integers starting at 1.""" """Numbering function: consecutive integers starting at 1."""
return index + 1 return index + 1
def count_from_n_factory(start): def count_from_n_factory(start):
"""Numbering function: consecutive integers starting at arbitrary start.""" """Numbering function: consecutive integers starting at arbitrary start."""
@ -139,8 +186,9 @@ def count_from_n_factory(start):
pass pass
return f return f
def _unsugar_count_from(**kw): def _unsugar_count_from(**kw):
"""Builds counting functions from keywrod arguments. """Builds counting functions from keyword arguments.
Keyword argument filter, prepares a simple ``ordering_func`` from a Keyword argument filter, prepares a simple ``ordering_func`` from a
``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
@ -156,12 +204,13 @@ def _unsugar_count_from(**kw):
kw['ordering_func'] = count_from_n_factory(count_from) kw['ordering_func'] = count_from_n_factory(count_from)
return kw return kw
class OrderingList(list): class OrderingList(list):
"""A custom list that manages position information for its children. """A custom list that manages position information for its children.
See the module and __init__ documentation for more details. The The :class:`.OrderingList` object is normally set up using the
``ordering_list`` factory function is used to configure ``OrderingList`` :func:`.ordering_list` factory function, used in conjunction with
collections in ``mapper`` relationship definitions. the :func:`.relationship` function.
""" """
@ -176,14 +225,14 @@ class OrderingList(list):
This implementation relies on the list starting in the proper order, This implementation relies on the list starting in the proper order,
so be **sure** to put an ``order_by`` on your relationship. so be **sure** to put an ``order_by`` on your relationship.
ordering_attr :param ordering_attr:
Name of the attribute that stores the object's order in the Name of the attribute that stores the object's order in the
relationship. relationship.
ordering_func :param ordering_func: Optional. A function that maps the position in
Optional. A function that maps the position in the Python list to a the Python list to a value to store in the
value to store in the ``ordering_attr``. Values returned are ``ordering_attr``. Values returned are usually (but need not be!)
usually (but need not be!) integers. integers.
An ``ordering_func`` is called with two positional parameters: the An ``ordering_func`` is called with two positional parameters: the
index of the element in the list, and the list itself. index of the element in the list, and the list itself.
@ -194,7 +243,7 @@ class OrderingList(list):
like stepped numbering, alphabetical and Fibonacci numbering, see like stepped numbering, alphabetical and Fibonacci numbering, see
the unit tests. the unit tests.
reorder_on_append :param reorder_on_append:
Default False. When appending an object with an existing (non-None) Default False. When appending an object with an existing (non-None)
ordering value, that value will be left untouched unless ordering value, that value will be left untouched unless
``reorder_on_append`` is true. This is an optimization to avoid a ``reorder_on_append`` is true. This is an optimization to avoid a
@ -208,7 +257,7 @@ class OrderingList(list):
making changes, any of whom happen to load this collection even in making changes, any of whom happen to load this collection even in
passing, all of the sessions would try to "clean up" the numbering passing, all of the sessions would try to "clean up" the numbering
in their commits, possibly causing all but one to fail with a in their commits, possibly causing all but one to fail with a
concurrent modification error. Spooky action at a distance. concurrent modification error.
Recommend leaving this with the default of False, and just call Recommend leaving this with the default of False, and just call
``reorder()`` if you're doing ``append()`` operations with ``reorder()`` if you're doing ``append()`` operations with
@ -270,7 +319,10 @@ class OrderingList(list):
def remove(self, entity): def remove(self, entity):
super(OrderingList, self).remove(entity) super(OrderingList, self).remove(entity)
self._reorder()
adapter = collection_adapter(self)
if adapter and adapter._referenced_by_owner:
self._reorder()
def pop(self, index=-1): def pop(self, index=-1):
entity = super(OrderingList, self).pop(index) entity = super(OrderingList, self).pop(index)
@ -286,8 +338,8 @@ class OrderingList(list):
stop = index.stop or len(self) stop = index.stop or len(self)
if stop < 0: if stop < 0:
stop += len(self) stop += len(self)
for i in xrange(start, stop, step): for i in range(start, stop, step):
self.__setitem__(i, entity[i]) self.__setitem__(i, entity[i])
else: else:
self._order_entity(index, entity, True) self._order_entity(index, entity, True)
@ -297,7 +349,6 @@ class OrderingList(list):
super(OrderingList, self).__delitem__(index) super(OrderingList, self).__delitem__(index)
self._reorder() self._reorder()
# Py2K
def __setslice__(self, start, end, values): def __setslice__(self, start, end, values):
super(OrderingList, self).__setslice__(start, end, values) super(OrderingList, self).__setslice__(start, end, values)
self._reorder() self._reorder()
@ -305,11 +356,25 @@ class OrderingList(list):
def __delslice__(self, start, end): def __delslice__(self, start, end):
super(OrderingList, self).__delslice__(start, end) super(OrderingList, self).__delslice__(start, end)
self._reorder() self._reorder()
# end Py2K
def __reduce__(self):
for func_name, func in locals().items(): return _reconstitute, (self.__class__, self.__dict__, list(self))
if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(list, func_name)): for func_name, func in list(locals().items()):
if (util.callable(func) and func.__name__ == func_name and
not func.__doc__ and hasattr(list, func_name)):
func.__doc__ = getattr(list, func_name).__doc__ func.__doc__ = getattr(list, func_name).__doc__
del func_name, func del func_name, func
def _reconstitute(cls, dict_, items):
""" Reconstitute an :class:`.OrderingList`.
This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for
unpickling :class:`.OrderingList` objects.
"""
obj = cls.__new__(cls)
obj.__dict__.update(dict_)
list.extend(obj, items)
return obj

View File

@ -1,4 +1,11 @@
"""Serializer/Deserializer objects for usage with SQLAlchemy query structures, # ext/serializer.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
allowing "contextual" deserialization. allowing "contextual" deserialization.
Any SQLAlchemy query structure, either based on sqlalchemy.sql.* Any SQLAlchemy query structure, either based on sqlalchemy.sql.*
@ -12,82 +19,73 @@ Usage is nearly the same as that of the standard Python pickle module::
from sqlalchemy.ext.serializer import loads, dumps from sqlalchemy.ext.serializer import loads, dumps
metadata = MetaData(bind=some_engine) metadata = MetaData(bind=some_engine)
Session = scoped_session(sessionmaker()) Session = scoped_session(sessionmaker())
# ... define mappers # ... define mappers
query = Session.query(MyClass).filter(MyClass.somedata=='foo').order_by(MyClass.sortkey) query = Session.query(MyClass).
filter(MyClass.somedata=='foo').order_by(MyClass.sortkey)
# pickle the query # pickle the query
serialized = dumps(query) serialized = dumps(query)
# unpickle. Pass in metadata + scoped_session # unpickle. Pass in metadata + scoped_session
query2 = loads(serialized, metadata, Session) query2 = loads(serialized, metadata, Session)
print query2.all() print query2.all()
Similar restrictions as when using raw pickle apply; mapped classes must be Similar restrictions as when using raw pickle apply; mapped classes must be
themselves be pickleable, meaning they are importable from a module-level themselves be pickleable, meaning they are importable from a module-level
namespace. namespace.
The serializer module is only appropriate for query structures. It is not The serializer module is only appropriate for query structures. It is not
needed for: needed for:
* instances of user-defined classes. These contain no references to engines, * instances of user-defined classes. These contain no references to engines,
sessions or expression constructs in the typical case and can be serialized directly. sessions or expression constructs in the typical case and can be serialized
directly.
* Table metadata that is to be loaded entirely from the serialized structure (i.e. is * Table metadata that is to be loaded entirely from the serialized structure
not already declared in the application). Regular pickle.loads()/dumps() can (i.e. is not already declared in the application). Regular
be used to fully dump any ``MetaData`` object, typically one which was reflected pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
from an existing database at some previous point in time. The serializer module typically one which was reflected from an existing database at some previous
is specifically for the opposite case, where the Table metadata is already present point in time. The serializer module is specifically for the opposite case,
in memory. where the Table metadata is already present in memory.
""" """
from sqlalchemy.orm import class_mapper, Query from ..orm import class_mapper
from sqlalchemy.orm.session import Session from ..orm.session import Session
from sqlalchemy.orm.mapper import Mapper from ..orm.mapper import Mapper
from sqlalchemy.orm.attributes import QueryableAttribute from ..orm.interfaces import MapperProperty
from sqlalchemy import Table, Column from ..orm.attributes import QueryableAttribute
from sqlalchemy.engine import Engine from .. import Table, Column
from sqlalchemy.util import pickle from ..engine import Engine
from ..util import pickle, byte_buffer, b64encode, b64decode, text_type
import re import re
import base64
# Py3K
#from io import BytesIO as byte_buffer
# Py2K
from cStringIO import StringIO as byte_buffer
# end Py2K
# Py3K
#def b64encode(x):
# return base64.b64encode(x).decode('ascii')
#def b64decode(x):
# return base64.b64decode(x.encode('ascii'))
# Py2K
b64encode = base64.b64encode
b64decode = base64.b64decode
# end Py2K
__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads'] __all__ = ['Serializer', 'Deserializer', 'dumps', 'loads']
def Serializer(*args, **kw): def Serializer(*args, **kw):
pickler = pickle.Pickler(*args, **kw) pickler = pickle.Pickler(*args, **kw)
def persistent_id(obj): def persistent_id(obj):
#print "serializing:", repr(obj) # print "serializing:", repr(obj)
if isinstance(obj, QueryableAttribute): if isinstance(obj, QueryableAttribute):
cls = obj.impl.class_ cls = obj.impl.class_
key = obj.impl.key key = obj.impl.key
id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls)) id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls))
elif isinstance(obj, Mapper) and not obj.non_primary: elif isinstance(obj, Mapper) and not obj.non_primary:
id = "mapper:" + b64encode(pickle.dumps(obj.class_)) id = "mapper:" + b64encode(pickle.dumps(obj.class_))
elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
id = "mapperprop:" + b64encode(pickle.dumps(obj.parent.class_)) + \
":" + obj.key
elif isinstance(obj, Table): elif isinstance(obj, Table):
id = "table:" + str(obj) id = "table:" + text_type(obj.key)
elif isinstance(obj, Column) and isinstance(obj.table, Table): elif isinstance(obj, Column) and isinstance(obj.table, Table):
id = "column:" + str(obj.table) + ":" + obj.key id = "column:" + \
text_type(obj.table.key) + ":" + text_type(obj.key)
elif isinstance(obj, Session): elif isinstance(obj, Session):
id = "session:" id = "session:"
elif isinstance(obj, Engine): elif isinstance(obj, Engine):
@ -95,15 +93,17 @@ def Serializer(*args, **kw):
else: else:
return None return None
return id return id
pickler.persistent_id = persistent_id pickler.persistent_id = persistent_id
return pickler return pickler
our_ids = re.compile(r'(mapper|table|column|session|attribute|engine):(.*)') our_ids = re.compile(
r'(mapperprop|mapper|table|column|session|attribute|engine):(.*)')
def Deserializer(file, metadata=None, scoped_session=None, engine=None): def Deserializer(file, metadata=None, scoped_session=None, engine=None):
unpickler = pickle.Unpickler(file) unpickler = pickle.Unpickler(file)
def get_engine(): def get_engine():
if engine: if engine:
return engine return engine
@ -113,9 +113,9 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
return metadata.bind return metadata.bind
else: else:
return None return None
def persistent_load(id): def persistent_load(id):
m = our_ids.match(id) m = our_ids.match(text_type(id))
if not m: if not m:
return None return None
else: else:
@ -127,6 +127,10 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
elif type_ == "mapper": elif type_ == "mapper":
cls = pickle.loads(b64decode(args)) cls = pickle.loads(b64decode(args))
return class_mapper(cls) return class_mapper(cls)
elif type_ == "mapperprop":
mapper, keyname = args.split(':')
cls = pickle.loads(b64decode(mapper))
return class_mapper(cls).attrs[keyname]
elif type_ == "table": elif type_ == "table":
return metadata.tables[args] return metadata.tables[args]
elif type_ == "column": elif type_ == "column":
@ -141,15 +145,15 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
unpickler.persistent_load = persistent_load unpickler.persistent_load = persistent_load
return unpickler return unpickler
def dumps(obj, protocol=0): def dumps(obj, protocol=0):
buf = byte_buffer() buf = byte_buffer()
pickler = Serializer(buf, protocol) pickler = Serializer(buf, protocol)
pickler.dump(obj) pickler.dump(obj)
return buf.getvalue() return buf.getvalue()
def loads(data, metadata=None, scoped_session=None, engine=None): def loads(data, metadata=None, scoped_session=None, engine=None):
buf = byte_buffer(data) buf = byte_buffer(data)
unpickler = Deserializer(buf, metadata, scoped_session, engine) unpickler = Deserializer(buf, metadata, scoped_session, engine)
return unpickler.load() return unpickler.load()

View File

@ -1,31 +1,45 @@
# interfaces.py # sqlalchemy/interfaces.py
# Copyright (C) 2007-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# Copyright (C) 2007 Jason Kirtland jek@discorporate.us # Copyright (C) 2007 Jason Kirtland jek@discorporate.us
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Interfaces and abstract types.""" """Deprecated core event interfaces.
This module is **deprecated** and is superseded by the
event system.
"""
from . import event, util
class PoolListener(object): class PoolListener(object):
"""Hooks into the lifecycle of connections in a ``Pool``. """Hooks into the lifecycle of connections in a :class:`.Pool`.
.. note::
:class:`.PoolListener` is deprecated. Please
refer to :class:`.PoolEvents`.
Usage:: Usage::
class MyListener(PoolListener): class MyListener(PoolListener):
def connect(self, dbapi_con, con_record): def connect(self, dbapi_con, con_record):
'''perform connect operations''' '''perform connect operations'''
# etc. # etc.
# create a new pool with a listener # create a new pool with a listener
p = QueuePool(..., listeners=[MyListener()]) p = QueuePool(..., listeners=[MyListener()])
# add a listener after the fact # add a listener after the fact
p.add_listener(MyListener()) p.add_listener(MyListener())
# usage with create_engine() # usage with create_engine()
e = create_engine("url://", listeners=[MyListener()]) e = create_engine("url://", listeners=[MyListener()])
All of the standard connection :class:`~sqlalchemy.pool.Pool` types can All of the standard connection :class:`~sqlalchemy.pool.Pool` types can
accept event listeners for key connection lifecycle events: accept event listeners for key connection lifecycle events:
creation, pool check-out and check-in. There are no events fired creation, pool check-out and check-in. There are no events fired
@ -56,9 +70,28 @@ class PoolListener(object):
internal event queues based on its capabilities. In terms of internal event queues based on its capabilities. In terms of
efficiency and function call overhead, you're much better off only efficiency and function call overhead, you're much better off only
providing implementations for the hooks you'll be using. providing implementations for the hooks you'll be using.
""" """
@classmethod
def _adapt_listener(cls, self, listener):
"""Adapt a :class:`.PoolListener` to individual
:class:`event.Dispatch` events.
"""
listener = util.as_interface(listener,
methods=('connect', 'first_connect',
'checkout', 'checkin'))
if hasattr(listener, 'connect'):
event.listen(self, 'connect', listener.connect)
if hasattr(listener, 'first_connect'):
event.listen(self, 'first_connect', listener.first_connect)
if hasattr(listener, 'checkout'):
event.listen(self, 'checkout', listener.checkout)
if hasattr(listener, 'checkin'):
event.listen(self, 'checkin', listener.checkin)
def connect(self, dbapi_con, con_record): def connect(self, dbapi_con, con_record):
"""Called once for each new DB-API connection or Pool's ``creator()``. """Called once for each new DB-API connection or Pool's ``creator()``.
@ -117,89 +150,163 @@ class PoolListener(object):
""" """
class ConnectionProxy(object): class ConnectionProxy(object):
"""Allows interception of statement execution by Connections. """Allows interception of statement execution by Connections.
.. note::
:class:`.ConnectionProxy` is deprecated. Please
refer to :class:`.ConnectionEvents`.
Either or both of the ``execute()`` and ``cursor_execute()`` Either or both of the ``execute()`` and ``cursor_execute()``
may be implemented to intercept compiled statement and may be implemented to intercept compiled statement and
cursor level executions, e.g.:: cursor level executions, e.g.::
class MyProxy(ConnectionProxy): class MyProxy(ConnectionProxy):
def execute(self, conn, execute, clauseelement, *multiparams, **params): def execute(self, conn, execute, clauseelement,
*multiparams, **params):
print "compiled statement:", clauseelement print "compiled statement:", clauseelement
return execute(clauseelement, *multiparams, **params) return execute(clauseelement, *multiparams, **params)
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): def cursor_execute(self, execute, cursor, statement,
parameters, context, executemany):
print "raw statement:", statement print "raw statement:", statement
return execute(cursor, statement, parameters, context) return execute(cursor, statement, parameters, context)
The ``execute`` argument is a function that will fulfill the default The ``execute`` argument is a function that will fulfill the default
execution behavior for the operation. The signature illustrated execution behavior for the operation. The signature illustrated
in the example should be used. in the example should be used.
The proxy is installed into an :class:`~sqlalchemy.engine.Engine` via The proxy is installed into an :class:`~sqlalchemy.engine.Engine` via
the ``proxy`` argument:: the ``proxy`` argument::
e = create_engine('someurl://', proxy=MyProxy()) e = create_engine('someurl://', proxy=MyProxy())
""" """
@classmethod
def _adapt_listener(cls, self, listener):
def adapt_execute(conn, clauseelement, multiparams, params):
def execute_wrapper(clauseelement, *multiparams, **params):
return clauseelement, multiparams, params
return listener.execute(conn, execute_wrapper,
clauseelement, *multiparams,
**params)
event.listen(self, 'before_execute', adapt_execute)
def adapt_cursor_execute(conn, cursor, statement,
parameters, context, executemany):
def execute_wrapper(
cursor,
statement,
parameters,
context,
):
return statement, parameters
return listener.cursor_execute(
execute_wrapper,
cursor,
statement,
parameters,
context,
executemany,
)
event.listen(self, 'before_cursor_execute', adapt_cursor_execute)
def do_nothing_callback(*arg, **kw):
pass
def adapt_listener(fn):
def go(conn, *arg, **kw):
fn(conn, do_nothing_callback, *arg, **kw)
return util.update_wrapper(go, fn)
event.listen(self, 'begin', adapt_listener(listener.begin))
event.listen(self, 'rollback',
adapt_listener(listener.rollback))
event.listen(self, 'commit', adapt_listener(listener.commit))
event.listen(self, 'savepoint',
adapt_listener(listener.savepoint))
event.listen(self, 'rollback_savepoint',
adapt_listener(listener.rollback_savepoint))
event.listen(self, 'release_savepoint',
adapt_listener(listener.release_savepoint))
event.listen(self, 'begin_twophase',
adapt_listener(listener.begin_twophase))
event.listen(self, 'prepare_twophase',
adapt_listener(listener.prepare_twophase))
event.listen(self, 'rollback_twophase',
adapt_listener(listener.rollback_twophase))
event.listen(self, 'commit_twophase',
adapt_listener(listener.commit_twophase))
def execute(self, conn, execute, clauseelement, *multiparams, **params): def execute(self, conn, execute, clauseelement, *multiparams, **params):
"""Intercept high level execute() events.""" """Intercept high level execute() events."""
return execute(clauseelement, *multiparams, **params) return execute(clauseelement, *multiparams, **params)
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): def cursor_execute(self, execute, cursor, statement, parameters,
context, executemany):
"""Intercept low-level cursor execute() events.""" """Intercept low-level cursor execute() events."""
return execute(cursor, statement, parameters, context) return execute(cursor, statement, parameters, context)
def begin(self, conn, begin): def begin(self, conn, begin):
"""Intercept begin() events.""" """Intercept begin() events."""
return begin() return begin()
def rollback(self, conn, rollback): def rollback(self, conn, rollback):
"""Intercept rollback() events.""" """Intercept rollback() events."""
return rollback() return rollback()
def commit(self, conn, commit): def commit(self, conn, commit):
"""Intercept commit() events.""" """Intercept commit() events."""
return commit() return commit()
def savepoint(self, conn, savepoint, name=None): def savepoint(self, conn, savepoint, name=None):
"""Intercept savepoint() events.""" """Intercept savepoint() events."""
return savepoint(name=name) return savepoint(name=name)
def rollback_savepoint(self, conn, rollback_savepoint, name, context): def rollback_savepoint(self, conn, rollback_savepoint, name, context):
"""Intercept rollback_savepoint() events.""" """Intercept rollback_savepoint() events."""
return rollback_savepoint(name, context) return rollback_savepoint(name, context)
def release_savepoint(self, conn, release_savepoint, name, context): def release_savepoint(self, conn, release_savepoint, name, context):
"""Intercept release_savepoint() events.""" """Intercept release_savepoint() events."""
return release_savepoint(name, context) return release_savepoint(name, context)
def begin_twophase(self, conn, begin_twophase, xid): def begin_twophase(self, conn, begin_twophase, xid):
"""Intercept begin_twophase() events.""" """Intercept begin_twophase() events."""
return begin_twophase(xid) return begin_twophase(xid)
def prepare_twophase(self, conn, prepare_twophase, xid): def prepare_twophase(self, conn, prepare_twophase, xid):
"""Intercept prepare_twophase() events.""" """Intercept prepare_twophase() events."""
return prepare_twophase(xid) return prepare_twophase(xid)
def rollback_twophase(self, conn, rollback_twophase, xid, is_prepared): def rollback_twophase(self, conn, rollback_twophase, xid, is_prepared):
"""Intercept rollback_twophase() events.""" """Intercept rollback_twophase() events."""
return rollback_twophase(xid, is_prepared) return rollback_twophase(xid, is_prepared)
def commit_twophase(self, conn, commit_twophase, xid, is_prepared): def commit_twophase(self, conn, commit_twophase, xid, is_prepared):
"""Intercept commit_twophase() events.""" """Intercept commit_twophase() events."""
return commit_twophase(xid, is_prepared) return commit_twophase(xid, is_prepared)

View File

@ -1,5 +1,7 @@
# log.py - adapt python logging module to SQLAlchemy # sqlalchemy/log.py
# Copyright (C) 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com # Copyright (C) 2006-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -10,92 +12,189 @@ Control of logging for SA can be performed from the regular python logging
module. The regular dotted module namespace is used, starting at module. The regular dotted module namespace is used, starting at
'sqlalchemy'. For class-level logging, the class name is appended. 'sqlalchemy'. For class-level logging, the class name is appended.
The "echo" keyword parameter which is available on SQLA ``Engine`` The "echo" keyword parameter, available on SQLA :class:`.Engine`
and ``Pool`` objects corresponds to a logger specific to that and :class:`.Pool` objects, corresponds to a logger specific to that
instance only. instance only.
E.g.::
engine.echo = True
is equivalent to::
import logging
logger = logging.getLogger('sqlalchemy.engine.Engine.%s' % hex(id(engine)))
logger.setLevel(logging.DEBUG)
""" """
import logging import logging
import sys import sys
from sqlalchemy import util
# set initial level to WARN. This so that
# log statements don't occur in the absence of explicit
# logging being enabled for 'sqlalchemy'.
rootlogger = logging.getLogger('sqlalchemy') rootlogger = logging.getLogger('sqlalchemy')
if rootlogger.level == logging.NOTSET: if rootlogger.level == logging.NOTSET:
rootlogger.setLevel(logging.WARN) rootlogger.setLevel(logging.WARN)
default_enabled = False
def default_logging(name): def _add_default_handler(logger):
global default_enabled handler = logging.StreamHandler(sys.stdout)
if logging.getLogger(name).getEffectiveLevel() < logging.WARN: handler.setFormatter(logging.Formatter(
default_enabled = True '%(asctime)s %(levelname)s %(name)s %(message)s'))
if not default_enabled: logger.addHandler(handler)
default_enabled = True
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)s %(name)s %(message)s'))
rootlogger.addHandler(handler)
_logged_classes = set() _logged_classes = set()
def class_logger(cls, enable=False):
def class_logger(cls):
logger = logging.getLogger(cls.__module__ + "." + cls.__name__) logger = logging.getLogger(cls.__module__ + "." + cls.__name__)
if enable == 'debug':
logger.setLevel(logging.DEBUG)
elif enable == 'info':
logger.setLevel(logging.INFO)
cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG) cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG)
cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO) cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO)
cls.logger = logger cls.logger = logger
_logged_classes.add(cls) _logged_classes.add(cls)
return cls
class Identified(object): class Identified(object):
@util.memoized_property logging_name = None
def logging_name(self):
# limit the number of loggers by chopping off the hex(id).
# some novice users unfortunately create an unlimited number
# of Engines in their applications which would otherwise
# cause the app to run out of memory.
return "0x...%s" % hex(id(self))[-4:]
def _should_log_debug(self):
def instance_logger(instance, echoflag=None): return self.logger.isEnabledFor(logging.DEBUG)
"""create a logger for an instance that implements :class:`Identified`.
def _should_log_info(self):
Warning: this is an expensive call which also results in a permanent return self.logger.isEnabledFor(logging.INFO)
increase in memory overhead for each call. Use only for
low-volume, long-time-spanning objects.
class InstanceLogger(object):
"""A logger adapter (wrapper) for :class:`.Identified` subclasses.
This allows multiple instances (e.g. Engine or Pool instances)
to share a logger, but have its verbosity controlled on a
per-instance basis.
The basic functionality is to return a logging level
which is based on an instance's echo setting.
Default implementation is:
'debug' -> logging.DEBUG
True -> logging.INFO
False -> Effective level of underlying logger
(logging.WARNING by default)
None -> same as False
""" """
name = "%s.%s.%s" % (instance.__class__.__module__, # Map echo settings to logger levels
instance.__class__.__name__, instance.logging_name) _echo_map = {
None: logging.NOTSET,
if echoflag is not None: False: logging.NOTSET,
l = logging.getLogger(name) True: logging.INFO,
if echoflag == 'debug': 'debug': logging.DEBUG,
default_logging(name) }
l.setLevel(logging.DEBUG)
elif echoflag is True: def __init__(self, echo, name):
default_logging(name) self.echo = echo
l.setLevel(logging.INFO) self.logger = logging.getLogger(name)
elif echoflag is False:
l.setLevel(logging.WARN) # if echo flag is enabled and no handlers,
# add a handler to the list
if self._echo_map[echo] <= logging.INFO \
and not self.logger.handlers:
_add_default_handler(self.logger)
#
# Boilerplate convenience methods
#
def debug(self, msg, *args, **kwargs):
"""Delegate a debug call to the underlying logger."""
self.log(logging.DEBUG, msg, *args, **kwargs)
def info(self, msg, *args, **kwargs):
"""Delegate an info call to the underlying logger."""
self.log(logging.INFO, msg, *args, **kwargs)
def warning(self, msg, *args, **kwargs):
"""Delegate a warning call to the underlying logger."""
self.log(logging.WARNING, msg, *args, **kwargs)
warn = warning
def error(self, msg, *args, **kwargs):
"""
Delegate an error call to the underlying logger.
"""
self.log(logging.ERROR, msg, *args, **kwargs)
def exception(self, msg, *args, **kwargs):
"""Delegate an exception call to the underlying logger."""
kwargs["exc_info"] = 1
self.log(logging.ERROR, msg, *args, **kwargs)
def critical(self, msg, *args, **kwargs):
"""Delegate a critical call to the underlying logger."""
self.log(logging.CRITICAL, msg, *args, **kwargs)
def log(self, level, msg, *args, **kwargs):
"""Delegate a log call to the underlying logger.
The level here is determined by the echo
flag as well as that of the underlying logger, and
logger._log() is called directly.
"""
# inline the logic from isEnabledFor(),
# getEffectiveLevel(), to avoid overhead.
if self.logger.manager.disable >= level:
return
selected_level = self._echo_map[self.echo]
if selected_level == logging.NOTSET:
selected_level = self.logger.getEffectiveLevel()
if level >= selected_level:
self.logger._log(level, msg, args, **kwargs)
def isEnabledFor(self, level):
"""Is this logger enabled for level 'level'?"""
if self.logger.manager.disable >= level:
return False
return level >= self.getEffectiveLevel()
def getEffectiveLevel(self):
"""What's the effective level for this logger?"""
level = self._echo_map[self.echo]
if level == logging.NOTSET:
level = self.logger.getEffectiveLevel()
return level
def instance_logger(instance, echoflag=None):
"""create a logger for an instance that implements :class:`.Identified`."""
if instance.logging_name:
name = "%s.%s.%s" % (instance.__class__.__module__,
instance.__class__.__name__,
instance.logging_name)
else: else:
l = logging.getLogger(name) name = "%s.%s" % (instance.__class__.__module__,
instance._should_log_debug = lambda: l.isEnabledFor(logging.DEBUG) instance.__class__.__name__)
instance._should_log_info = lambda: l.isEnabledFor(logging.INFO)
return l instance._echo = echoflag
if echoflag in (False, None):
# if no echo setting or False, return a Logger directly,
# avoiding overhead of filtering
logger = logging.getLogger(name)
else:
# if a specified echo flag, return an EchoLogger,
# which checks the flag, overrides normal log
# levels by calling logger._log()
logger = InstanceLogger(echoflag, name)
instance.logger = logger
class echo_property(object): class echo_property(object):
__doc__ = """\ __doc__ = """\
@ -112,8 +211,7 @@ class echo_property(object):
if instance is None: if instance is None:
return self return self
else: else:
return instance._should_log_debug() and 'debug' or \ return instance._echo
(instance._should_log_info() and True or False)
def __set__(self, instance, value): def __set__(self, instance, value):
instance_logger(instance, echoflag=value) instance_logger(instance, echoflag=value)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,6 @@
# dynamic.py # orm/dynamic.py
# Copyright (C) the SQLAlchemy authors and contributors # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -11,42 +12,47 @@ basic add/delete mutation.
""" """
from sqlalchemy import log, util from .. import log, util, exc
from sqlalchemy import exc as sa_exc from ..sql import operators
from sqlalchemy.orm import exc as sa_exc from . import (
from sqlalchemy.sql import operators attributes, object_session, util as orm_util, strategies,
from sqlalchemy.orm import ( object_mapper, exc as orm_exc, properties
attributes, object_session, util as mapperutil, strategies, object_mapper )
) from .query import Query
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.util import _state_has_identity, has_identity
from sqlalchemy.orm import attributes, collections
@log.class_logger
@properties.RelationshipProperty.strategy_for(lazy="dynamic")
class DynaLoader(strategies.AbstractRelationshipLoader): class DynaLoader(strategies.AbstractRelationshipLoader):
def init_class_attribute(self, mapper): def init_class_attribute(self, mapper):
self.is_class_level = True self.is_class_level = True
if not self.uselist:
strategies._register_attribute(self, raise exc.InvalidRequestError(
"On relationship %s, 'dynamic' loaders cannot be used with "
"many-to-one/one-to-one relationships and/or "
"uselist=False." % self.parent_property)
strategies._register_attribute(
self.parent_property,
mapper, mapper,
useobject=True, useobject=True,
impl_class=DynamicAttributeImpl, impl_class=DynamicAttributeImpl,
target_mapper=self.parent_property.mapper, target_mapper=self.parent_property.mapper,
order_by=self.parent_property.order_by, order_by=self.parent_property.order_by,
query_class=self.parent_property.query_class query_class=self.parent_property.query_class,
) )
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
return (None, None)
log.class_logger(DynaLoader)
class DynamicAttributeImpl(attributes.AttributeImpl): class DynamicAttributeImpl(attributes.AttributeImpl):
uses_objects = True uses_objects = True
accepts_scalar_loader = False accepts_scalar_loader = False
supports_population = False
collection = False
def __init__(self, class_, key, typecallable, def __init__(self, class_, key, typecallable,
target_mapper, order_by, query_class=None, **kwargs): dispatch,
super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs) target_mapper, order_by, query_class=None, **kw):
super(DynamicAttributeImpl, self).\
__init__(class_, key, typecallable, dispatch, **kw)
self.target_mapper = target_mapper self.target_mapper = target_mapper
self.order_by = order_by self.order_by = order_by
if not query_class: if not query_class:
@ -56,178 +62,204 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
else: else:
self.query_class = mixin_user_query(query_class) self.query_class = mixin_user_query(query_class)
def get(self, state, dict_, passive=False): def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
if passive: if not passive & attributes.SQL_OK:
return self._get_collection_history(state, passive=True).added_items return self._get_collection_history(
state, attributes.PASSIVE_NO_INITIALIZE).added_items
else: else:
return self.query_class(self, state) return self.query_class(self, state)
def get_collection(self, state, dict_, user_data=None, passive=True): def get_collection(self, state, dict_, user_data=None,
if passive: passive=attributes.PASSIVE_NO_INITIALIZE):
return self._get_collection_history(state, passive=passive).added_items if not passive & attributes.SQL_OK:
return self._get_collection_history(state,
passive).added_items
else: else:
history = self._get_collection_history(state, passive=passive) history = self._get_collection_history(state, passive)
return history.added_items + history.unchanged_items return history.added_plus_unchanged
def fire_append_event(self, state, dict_, value, initiator): @util.memoized_property
collection_history = self._modified_event(state, dict_) def _append_token(self):
collection_history.added_items.append(value) return attributes.Event(self, attributes.OP_APPEND)
for ext in self.extensions: @util.memoized_property
ext.append(state, value, initiator or self) def _remove_token(self):
return attributes.Event(self, attributes.OP_REMOVE)
def fire_append_event(self, state, dict_, value, initiator,
collection_history=None):
if collection_history is None:
collection_history = self._modified_event(state, dict_)
collection_history.add_added(value)
for fn in self.dispatch.append:
value = fn(state, value, initiator or self._append_token)
if self.trackparent and value is not None: if self.trackparent and value is not None:
self.sethasparent(attributes.instance_state(value), True) self.sethasparent(attributes.instance_state(value), state, True)
def fire_remove_event(self, state, dict_, value, initiator): def fire_remove_event(self, state, dict_, value, initiator,
collection_history = self._modified_event(state, dict_) collection_history=None):
collection_history.deleted_items.append(value) if collection_history is None:
collection_history = self._modified_event(state, dict_)
collection_history.add_removed(value)
if self.trackparent and value is not None: if self.trackparent and value is not None:
self.sethasparent(attributes.instance_state(value), False) self.sethasparent(attributes.instance_state(value), state, False)
for ext in self.extensions: for fn in self.dispatch.remove:
ext.remove(state, value, initiator or self) fn(state, value, initiator or self._remove_token)
def _modified_event(self, state, dict_): def _modified_event(self, state, dict_):
if self.key not in state.committed_state: if self.key not in state.committed_state:
state.committed_state[self.key] = CollectionHistory(self, state) state.committed_state[self.key] = CollectionHistory(self, state)
state.modified_event(dict_, state._modified_event(dict_,
self, self,
False, attributes.NEVER_SET)
attributes.NEVER_SET,
passive=attributes.PASSIVE_NO_INITIALIZE)
# this is a hack to allow the _base.ComparableEntity fixture # this is a hack to allow the fixtures.ComparableEntity fixture
# to work # to work
dict_[self.key] = True dict_[self.key] = True
return state.committed_state[self.key] return state.committed_state[self.key]
def set(self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF): def set(self, state, dict_, value, initiator=None,
if initiator is self: passive=attributes.PASSIVE_OFF,
check_old=None, pop=False, _adapt=True):
if initiator and initiator.parent_token is self.parent_token:
return return
self._set_iterable(state, dict_, value) if pop and value is None:
return
def _set_iterable(self, state, dict_, iterable, adapter=None): iterable = value
new_values = list(iterable)
if state.has_identity:
old_collection = util.IdentitySet(self.get(state, dict_))
collection_history = self._modified_event(state, dict_) collection_history = self._modified_event(state, dict_)
new_values = list(iterable) if not state.has_identity:
old_collection = collection_history.added_items
if _state_has_identity(state):
old_collection = list(self.get(state, dict_))
else: else:
old_collection = [] old_collection = old_collection.union(
collection_history.added_items)
collections.bulk_replace(new_values, DynCollectionAdapter(self, state, old_collection), DynCollectionAdapter(self, state, new_values)) idset = util.IdentitySet
constants = old_collection.intersection(new_values)
additions = idset(new_values).difference(constants)
removals = old_collection.difference(constants)
for member in new_values:
if member in additions:
self.fire_append_event(state, dict_, member, None,
collection_history=collection_history)
for member in removals:
self.fire_remove_event(state, dict_, member, None,
collection_history=collection_history)
def delete(self, *args, **kwargs): def delete(self, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
def get_history(self, state, dict_, passive=False): def set_committed_value(self, state, dict_, value):
c = self._get_collection_history(state, passive) raise NotImplementedError("Dynamic attributes don't support "
return attributes.History(c.added_items, c.unchanged_items, c.deleted_items) "collection population.")
def _get_collection_history(self, state, passive=False): def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
c = self._get_collection_history(state, passive)
return c.as_history()
def get_all_pending(self, state, dict_,
passive=attributes.PASSIVE_NO_INITIALIZE):
c = self._get_collection_history(
state, passive)
return [
(attributes.instance_state(x), x)
for x in
c.all_items
]
def _get_collection_history(self, state, passive=attributes.PASSIVE_OFF):
if self.key in state.committed_state: if self.key in state.committed_state:
c = state.committed_state[self.key] c = state.committed_state[self.key]
else: else:
c = CollectionHistory(self, state) c = CollectionHistory(self, state)
if not passive: if state.has_identity and (passive & attributes.INIT_OK):
return CollectionHistory(self, state, apply_to=c) return CollectionHistory(self, state, apply_to=c)
else: else:
return c return c
def append(self, state, dict_, value, initiator, passive=False): def append(self, state, dict_, value, initiator,
passive=attributes.PASSIVE_OFF):
if initiator is not self: if initiator is not self:
self.fire_append_event(state, dict_, value, initiator) self.fire_append_event(state, dict_, value, initiator)
def remove(self, state, dict_, value, initiator, passive=False): def remove(self, state, dict_, value, initiator,
passive=attributes.PASSIVE_OFF):
if initiator is not self: if initiator is not self:
self.fire_remove_event(state, dict_, value, initiator) self.fire_remove_event(state, dict_, value, initiator)
class DynCollectionAdapter(object): def pop(self, state, dict_, value, initiator,
"""the dynamic analogue to orm.collections.CollectionAdapter""" passive=attributes.PASSIVE_OFF):
self.remove(state, dict_, value, initiator, passive=passive)
def __init__(self, attr, owner_state, data):
self.attr = attr
self.state = owner_state
self.data = data
def __iter__(self):
return iter(self.data)
def append_with_event(self, item, initiator=None):
self.attr.append(self.state, self.state.dict, item, initiator)
def remove_with_event(self, item, initiator=None):
self.attr.remove(self.state, self.state.dict, item, initiator)
def append_without_event(self, item):
pass
def remove_without_event(self, item):
pass
class AppenderMixin(object): class AppenderMixin(object):
query_class = None query_class = None
def __init__(self, attr, state): def __init__(self, attr, state):
Query.__init__(self, attr.target_mapper, None) super(AppenderMixin, self).__init__(attr.target_mapper, None)
self.instance = instance = state.obj() self.instance = instance = state.obj()
self.attr = attr self.attr = attr
mapper = object_mapper(instance) mapper = object_mapper(instance)
prop = mapper.get_property(self.attr.key, resolve_synonyms=True) prop = mapper._props[self.attr.key]
self._criterion = prop.compare( self._criterion = prop._with_parent(
operators.eq, instance,
instance, alias_secondary=False)
value_is_parent=True,
alias_secondary=False)
if self.attr.order_by: if self.attr.order_by:
self._order_by = self.attr.order_by self._order_by = self.attr.order_by
def __session(self): def session(self):
sess = object_session(self.instance) sess = object_session(self.instance)
if sess is not None and self.autoflush and sess.autoflush and self.instance in sess: if sess is not None and self.autoflush and sess.autoflush \
and self.instance in sess:
sess.flush() sess.flush()
if not has_identity(self.instance): if not orm_util.has_identity(self.instance):
return None return None
else: else:
return sess return sess
session = property(session, lambda s, x: None)
def session(self):
return self.__session()
session = property(session, lambda s, x:None)
def __iter__(self): def __iter__(self):
sess = self.__session() sess = self.session
if sess is None: if sess is None:
return iter(self.attr._get_collection_history( return iter(self.attr._get_collection_history(
attributes.instance_state(self.instance), attributes.instance_state(self.instance),
passive=True).added_items) attributes.PASSIVE_NO_INITIALIZE).added_items)
else: else:
return iter(self._clone(sess)) return iter(self._clone(sess))
def __getitem__(self, index): def __getitem__(self, index):
sess = self.__session() sess = self.session
if sess is None: if sess is None:
return self.attr._get_collection_history( return self.attr._get_collection_history(
attributes.instance_state(self.instance), attributes.instance_state(self.instance),
passive=True).added_items.__getitem__(index) attributes.PASSIVE_NO_INITIALIZE).indexed(index)
else: else:
return self._clone(sess).__getitem__(index) return self._clone(sess).__getitem__(index)
def count(self): def count(self):
sess = self.__session() sess = self.session
if sess is None: if sess is None:
return len(self.attr._get_collection_history( return len(self.attr._get_collection_history(
attributes.instance_state(self.instance), attributes.instance_state(self.instance),
passive=True).added_items) attributes.PASSIVE_NO_INITIALIZE).added_items)
else: else:
return self._clone(sess).count() return self._clone(sess).count()
@ -243,26 +275,32 @@ class AppenderMixin(object):
"Parent instance %s is not bound to a Session, and no " "Parent instance %s is not bound to a Session, and no "
"contextual session is established; lazy load operation " "contextual session is established; lazy load operation "
"of attribute '%s' cannot proceed" % ( "of attribute '%s' cannot proceed" % (
mapperutil.instance_str(instance), self.attr.key)) orm_util.instance_str(instance), self.attr.key))
if self.query_class: if self.query_class:
query = self.query_class(self.attr.target_mapper, session=sess) query = self.query_class(self.attr.target_mapper, session=sess)
else: else:
query = sess.query(self.attr.target_mapper) query = sess.query(self.attr.target_mapper)
query._criterion = self._criterion query._criterion = self._criterion
query._order_by = self._order_by query._order_by = self._order_by
return query return query
def extend(self, iterator):
for item in iterator:
self.attr.append(
attributes.instance_state(self.instance),
attributes.instance_dict(self.instance), item, None)
def append(self, item): def append(self, item):
self.attr.append( self.attr.append(
attributes.instance_state(self.instance), attributes.instance_state(self.instance),
attributes.instance_dict(self.instance), item, None) attributes.instance_dict(self.instance), item, None)
def remove(self, item): def remove(self, item):
self.attr.remove( self.attr.remove(
attributes.instance_state(self.instance), attributes.instance_state(self.instance),
attributes.instance_dict(self.instance), item, None) attributes.instance_dict(self.instance), item, None)
@ -275,19 +313,55 @@ def mixin_user_query(cls):
name = 'Appender' + cls.__name__ name = 'Appender' + cls.__name__
return type(name, (AppenderMixin, cls), {'query_class': cls}) return type(name, (AppenderMixin, cls), {'query_class': cls})
class CollectionHistory(object): class CollectionHistory(object):
"""Overrides AttributeHistory to receive append/remove events directly.""" """Overrides AttributeHistory to receive append/remove events directly."""
def __init__(self, attr, state, apply_to=None): def __init__(self, attr, state, apply_to=None):
if apply_to: if apply_to:
deleted = util.IdentitySet(apply_to.deleted_items)
added = apply_to.added_items
coll = AppenderQuery(attr, state).autoflush(False) coll = AppenderQuery(attr, state).autoflush(False)
self.unchanged_items = [o for o in util.IdentitySet(coll) if o not in deleted] self.unchanged_items = util.OrderedIdentitySet(coll)
self.added_items = apply_to.added_items self.added_items = apply_to.added_items
self.deleted_items = apply_to.deleted_items self.deleted_items = apply_to.deleted_items
self._reconcile_collection = True
else: else:
self.deleted_items = [] self.deleted_items = util.OrderedIdentitySet()
self.added_items = [] self.added_items = util.OrderedIdentitySet()
self.unchanged_items = [] self.unchanged_items = util.OrderedIdentitySet()
self._reconcile_collection = False
@property
def added_plus_unchanged(self):
return list(self.added_items.union(self.unchanged_items))
@property
def all_items(self):
return list(self.added_items.union(
self.unchanged_items).union(self.deleted_items))
def as_history(self):
if self._reconcile_collection:
added = self.added_items.difference(self.unchanged_items)
deleted = self.deleted_items.intersection(self.unchanged_items)
unchanged = self.unchanged_items.difference(deleted)
else:
added, unchanged, deleted = self.added_items,\
self.unchanged_items,\
self.deleted_items
return attributes.History(
list(added),
list(unchanged),
list(deleted),
)
def indexed(self, index):
return list(self.added_items)[index]
def add_added(self, value):
self.added_items.add(value)
def add_removed(self, value):
if value in self.added_items:
self.added_items.remove(value)
else:
self.deleted_items.add(value)

View File

@ -1,17 +1,21 @@
# orm/evaluator.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import operator import operator
from sqlalchemy.sql import operators, functions from ..sql import operators
from sqlalchemy.sql import expression as sql
class UnevaluatableError(Exception): class UnevaluatableError(Exception):
pass pass
_straight_ops = set(getattr(operators, op) _straight_ops = set(getattr(operators, op)
for op in ('add', 'mul', 'sub', for op in ('add', 'mul', 'sub',
# Py2K 'div',
'div', 'mod', 'truediv',
# end Py2K
'mod', 'truediv',
'lt', 'le', 'ne', 'gt', 'ge', 'eq')) 'lt', 'le', 'ne', 'gt', 'ge', 'eq'))
@ -20,11 +24,16 @@ _notimplemented_ops = set(getattr(operators, op)
'notilike_op', 'between_op', 'in_op', 'notilike_op', 'between_op', 'in_op',
'notin_op', 'endswith_op', 'concat_op')) 'notin_op', 'endswith_op', 'concat_op'))
class EvaluatorCompiler(object): class EvaluatorCompiler(object):
def __init__(self, target_cls=None):
self.target_cls = target_cls
def process(self, clause): def process(self, clause):
meth = getattr(self, "visit_%s" % clause.__visit_name__, None) meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
if not meth: if not meth:
raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__) raise UnevaluatableError(
"Cannot evaluate %s" % type(clause).__name__)
return meth(clause) return meth(clause)
def visit_grouping(self, clause): def visit_grouping(self, clause):
@ -33,16 +42,30 @@ class EvaluatorCompiler(object):
def visit_null(self, clause): def visit_null(self, clause):
return lambda obj: None return lambda obj: None
def visit_false(self, clause):
return lambda obj: False
def visit_true(self, clause):
return lambda obj: True
def visit_column(self, clause): def visit_column(self, clause):
if 'parentmapper' in clause._annotations: if 'parentmapper' in clause._annotations:
key = clause._annotations['parentmapper']._get_col_to_prop(clause).key parentmapper = clause._annotations['parentmapper']
if self.target_cls and not issubclass(
self.target_cls, parentmapper.class_):
raise UnevaluatableError(
"Can't evaluate criteria against alternate class %s" %
parentmapper.class_
)
key = parentmapper._columntoproperty[clause].key
else: else:
key = clause.key key = clause.key
get_corresponding_attr = operator.attrgetter(key) get_corresponding_attr = operator.attrgetter(key)
return lambda obj: get_corresponding_attr(obj) return lambda obj: get_corresponding_attr(obj)
def visit_clauselist(self, clause): def visit_clauselist(self, clause):
evaluators = map(self.process, clause.clauses) evaluators = list(map(self.process, clause.clauses))
if clause.operator is operators.or_: if clause.operator is operators.or_:
def evaluate(obj): def evaluate(obj):
has_null = False has_null = False
@ -64,12 +87,15 @@ class EvaluatorCompiler(object):
return False return False
return True return True
else: else:
raise UnevaluatableError("Cannot evaluate clauselist with operator %s" % clause.operator) raise UnevaluatableError(
"Cannot evaluate clauselist with operator %s" %
clause.operator)
return evaluate return evaluate
def visit_binary(self, clause): def visit_binary(self, clause):
eval_left,eval_right = map(self.process, [clause.left, clause.right]) eval_left, eval_right = list(map(self.process,
[clause.left, clause.right]))
operator = clause.operator operator = clause.operator
if operator is operators.is_: if operator is operators.is_:
def evaluate(obj): def evaluate(obj):
@ -85,7 +111,9 @@ class EvaluatorCompiler(object):
return None return None
return operator(eval_left(obj), eval_right(obj)) return operator(eval_left(obj), eval_right(obj))
else: else:
raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator)) raise UnevaluatableError(
"Cannot evaluate %s with operator %s" %
(type(clause).__name__, clause.operator))
return evaluate return evaluate
def visit_unary(self, clause): def visit_unary(self, clause):
@ -97,8 +125,13 @@ class EvaluatorCompiler(object):
return None return None
return not value return not value
return evaluate return evaluate
raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator)) raise UnevaluatableError(
"Cannot evaluate %s with operator %s" %
(type(clause).__name__, clause.operator))
def visit_bindparam(self, clause): def visit_bindparam(self, clause):
val = clause.value if clause.callable:
val = clause.callable()
else:
val = clause.value
return lambda obj: val return lambda obj: val

View File

@ -1,42 +1,79 @@
# exc.py - ORM exceptions # orm/exc.py
# Copyright (C) the SQLAlchemy authors and contributors # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
"""SQLAlchemy ORM exceptions.""" """SQLAlchemy ORM exceptions."""
from .. import exc as sa_exc, util
import sqlalchemy as sa
NO_STATE = (AttributeError, KeyError) NO_STATE = (AttributeError, KeyError)
"""Exception types that may be raised by instrumentation implementations.""" """Exception types that may be raised by instrumentation implementations."""
class ConcurrentModificationError(sa.exc.SQLAlchemyError):
"""Rows have been modified outside of the unit of work.""" class StaleDataError(sa_exc.SQLAlchemyError):
"""An operation encountered database state that is unaccounted for.
Conditions which cause this to happen include:
* A flush may have attempted to update or delete rows
and an unexpected number of rows were matched during
the UPDATE or DELETE statement. Note that when
version_id_col is used, rows in UPDATE or DELETE statements
are also matched against the current known version
identifier.
* A mapped object with version_id_col was refreshed,
and the version number coming back from the database does
not match that of the object itself.
* A object is detached from its parent object, however
the object was previously attached to a different parent
identity which was garbage collected, and a decision
cannot be made if the new parent was really the most
recent "parent".
.. versionadded:: 0.7.4
"""
ConcurrentModificationError = StaleDataError
class FlushError(sa.exc.SQLAlchemyError): class FlushError(sa_exc.SQLAlchemyError):
"""A invalid condition was detected during flush().""" """A invalid condition was detected during flush()."""
class UnmappedError(sa.exc.InvalidRequestError): class UnmappedError(sa_exc.InvalidRequestError):
"""TODO""" """Base for exceptions that involve expected mappings not present."""
class ObjectDereferencedError(sa_exc.SQLAlchemyError):
"""An operation cannot complete due to an object being garbage
collected.
"""
class DetachedInstanceError(sa_exc.SQLAlchemyError):
"""An attempt to access unloaded attributes on a
mapped instance that is detached."""
class DetachedInstanceError(sa.exc.SQLAlchemyError):
"""An attempt to access unloaded attributes on a mapped instance that is detached."""
class UnmappedInstanceError(UnmappedError): class UnmappedInstanceError(UnmappedError):
"""An mapping operation was requested for an unknown instance.""" """An mapping operation was requested for an unknown instance."""
def __init__(self, obj, msg=None): @util.dependencies("sqlalchemy.orm.base")
def __init__(self, base, obj, msg=None):
if not msg: if not msg:
try: try:
mapper = sa.orm.class_mapper(type(obj)) base.class_mapper(type(obj))
name = _safe_cls_name(type(obj)) name = _safe_cls_name(type(obj))
msg = ("Class %r is mapped, but this instance lacks " msg = ("Class %r is mapped, but this instance lacks "
"instrumentation. This occurs when the instance is created " "instrumentation. This occurs when the instance"
"before sqlalchemy.orm.mapper(%s) was called." % (name, name)) "is created before sqlalchemy.orm.mapper(%s) "
"was called." % (name, name))
except UnmappedClassError: except UnmappedClassError:
msg = _default_unmapped(type(obj)) msg = _default_unmapped(type(obj))
if isinstance(obj, type): if isinstance(obj, type):
@ -45,6 +82,9 @@ class UnmappedInstanceError(UnmappedError):
'required?' % _safe_cls_name(obj)) 'required?' % _safe_cls_name(obj))
UnmappedError.__init__(self, msg) UnmappedError.__init__(self, msg)
def __reduce__(self):
return self.__class__, (None, self.args[0])
class UnmappedClassError(UnmappedError): class UnmappedClassError(UnmappedError):
"""An mapping operation was requested for an unknown class.""" """An mapping operation was requested for an unknown class."""
@ -54,28 +94,53 @@ class UnmappedClassError(UnmappedError):
msg = _default_unmapped(cls) msg = _default_unmapped(cls)
UnmappedError.__init__(self, msg) UnmappedError.__init__(self, msg)
def __reduce__(self):
class ObjectDeletedError(sa.exc.InvalidRequestError): return self.__class__, (None, self.args[0])
"""An refresh() operation failed to re-retrieve an object's row."""
class UnmappedColumnError(sa.exc.InvalidRequestError): class ObjectDeletedError(sa_exc.InvalidRequestError):
"""A refresh operation failed to retrieve the database
row corresponding to an object's known primary key identity.
A refresh operation proceeds when an expired attribute is
accessed on an object, or when :meth:`.Query.get` is
used to retrieve an object which is, upon retrieval, detected
as expired. A SELECT is emitted for the target row
based on primary key; if no row is returned, this
exception is raised.
The true meaning of this exception is simply that
no row exists for the primary key identifier associated
with a persistent object. The row may have been
deleted, or in some cases the primary key updated
to a new value, outside of the ORM's management of the target
object.
"""
@util.dependencies("sqlalchemy.orm.base")
def __init__(self, base, state, msg=None):
if not msg:
msg = "Instance '%s' has been deleted, or its "\
"row is otherwise not present." % base.state_str(state)
sa_exc.InvalidRequestError.__init__(self, msg)
def __reduce__(self):
return self.__class__, (None, self.args[0])
class UnmappedColumnError(sa_exc.InvalidRequestError):
"""Mapping operation was requested on an unknown column.""" """Mapping operation was requested on an unknown column."""
class NoResultFound(sa.exc.InvalidRequestError): class NoResultFound(sa_exc.InvalidRequestError):
"""A database result was required but none was found.""" """A database result was required but none was found."""
class MultipleResultsFound(sa.exc.InvalidRequestError): class MultipleResultsFound(sa_exc.InvalidRequestError):
"""A single database result was required but more than one were found.""" """A single database result was required but more than one were found."""
# Legacy compat until 0.6.
sa.exc.ConcurrentModificationError = ConcurrentModificationError
sa.exc.FlushError = FlushError
sa.exc.UnmappedColumnError
def _safe_cls_name(cls): def _safe_cls_name(cls):
try: try:
cls_name = '.'.join((cls.__module__, cls.__name__)) cls_name = '.'.join((cls.__module__, cls.__name__))
@ -85,9 +150,11 @@ def _safe_cls_name(cls):
cls_name = repr(cls) cls_name = repr(cls)
return cls_name return cls_name
def _default_unmapped(cls):
@util.dependencies("sqlalchemy.orm.base")
def _default_unmapped(base, cls):
try: try:
mappers = sa.orm.attributes.manager_of_class(cls).mappers mappers = base.manager_of_class(cls).mappers
except NO_STATE: except NO_STATE:
mappers = {} mappers = {}
except TypeError: except TypeError:

View File

@ -1,67 +1,66 @@
# identity.py # orm/identity.py
# Copyright (C) the SQLAlchemy authors and contributors # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
import weakref import weakref
from . import attributes
from .. import util
from .. import exc as sa_exc
from . import util as orm_util
from sqlalchemy import util as base_util class IdentityMap(object):
from sqlalchemy.orm import attributes
class IdentityMap(dict):
def __init__(self): def __init__(self):
self._mutable_attrs = set() self._dict = {}
self._modified = set() self._modified = set()
self._wr = weakref.ref(self) self._wr = weakref.ref(self)
def keys(self):
return self._dict.keys()
def replace(self, state): def replace(self, state):
raise NotImplementedError() raise NotImplementedError()
def add(self, state): def add(self, state):
raise NotImplementedError() raise NotImplementedError()
def remove(self, state): def _add_unpresent(self, state, key):
raise NotImplementedError() """optional inlined form of add() which can assume item isn't present
in the map"""
self.add(state)
def update(self, dict): def update(self, dict):
raise NotImplementedError("IdentityMap uses add() to insert data") raise NotImplementedError("IdentityMap uses add() to insert data")
def clear(self): def clear(self):
raise NotImplementedError("IdentityMap uses remove() to remove data") raise NotImplementedError("IdentityMap uses remove() to remove data")
def _manage_incoming_state(self, state): def _manage_incoming_state(self, state):
state._instance_dict = self._wr state._instance_dict = self._wr
if state.modified: if state.modified:
self._modified.add(state) self._modified.add(state)
if state.manager.mutable_attributes:
self._mutable_attrs.add(state)
def _manage_removed_state(self, state): def _manage_removed_state(self, state):
del state._instance_dict del state._instance_dict
self._mutable_attrs.discard(state) if state.modified:
self._modified.discard(state) self._modified.discard(state)
def _dirty_states(self): def _dirty_states(self):
return self._modified.union(s for s in self._mutable_attrs.copy() return self._modified
if s.modified)
def check_modified(self): def check_modified(self):
"""return True if any InstanceStates present have been marked as 'modified'.""" """return True if any InstanceStates present have been marked
as 'modified'.
if self._modified:
return True """
else: return bool(self._modified)
for state in self._mutable_attrs.copy():
if state.modified:
return True
return False
def has_key(self, key): def has_key(self, key):
return key in self return key in self
def popitem(self): def popitem(self):
raise NotImplementedError("IdentityMap uses remove() to remove data") raise NotImplementedError("IdentityMap uses remove() to remove data")
@ -71,6 +70,9 @@ class IdentityMap(dict):
def setdefault(self, key, default=None): def setdefault(self, key, default=None):
raise NotImplementedError("IdentityMap uses add() to insert data") raise NotImplementedError("IdentityMap uses add() to insert data")
def __len__(self):
return len(self._dict)
def copy(self): def copy(self):
raise NotImplementedError() raise NotImplementedError()
@ -79,164 +81,233 @@ class IdentityMap(dict):
def __delitem__(self, key): def __delitem__(self, key):
raise NotImplementedError("IdentityMap uses remove() to remove data") raise NotImplementedError("IdentityMap uses remove() to remove data")
class WeakInstanceDict(IdentityMap): class WeakInstanceDict(IdentityMap):
def __getitem__(self, key): def __getitem__(self, key):
state = dict.__getitem__(self, key) state = self._dict[key]
o = state.obj() o = state.obj()
if o is None: if o is None:
o = state._is_really_none() raise KeyError(key)
if o is None:
raise KeyError, key
return o return o
def __contains__(self, key): def __contains__(self, key):
try: try:
if dict.__contains__(self, key): if key in self._dict:
state = dict.__getitem__(self, key) state = self._dict[key]
o = state.obj() o = state.obj()
if o is None:
o = state._is_really_none()
else: else:
return False return False
except KeyError: except KeyError:
return False return False
else: else:
return o is not None return o is not None
def contains_state(self, state): def contains_state(self, state):
return dict.get(self, state.key) is state return state.key in self._dict and self._dict[state.key] is state
def replace(self, state): def replace(self, state):
if dict.__contains__(self, state.key): if state.key in self._dict:
existing = dict.__getitem__(self, state.key) existing = self._dict[state.key]
if existing is not state: if existing is not state:
self._manage_removed_state(existing) self._manage_removed_state(existing)
else: else:
return return
dict.__setitem__(self, state.key, state) self._dict[state.key] = state
self._manage_incoming_state(state) self._manage_incoming_state(state)
def add(self, state): def add(self, state):
if state.key in self: key = state.key
if dict.__getitem__(self, state.key) is not state: # inline of self.__contains__
raise AssertionError("A conflicting state is already " if key in self._dict:
"present in the identity map for key %r" try:
% (state.key, )) existing_state = self._dict[key]
else: if existing_state is not state:
dict.__setitem__(self, state.key, state) o = existing_state.obj()
self._manage_incoming_state(state) if o is not None:
raise sa_exc.InvalidRequestError(
def remove_key(self, key): "Can't attach instance "
state = dict.__getitem__(self, key) "%s; another instance with key %s is already "
self.remove(state) "present in this session." % (
orm_util.state_str(state), state.key))
def remove(self, state): else:
if dict.pop(self, state.key) is not state: return False
raise AssertionError("State %s is not present in this identity map" % state) except KeyError:
self._manage_removed_state(state) pass
self._dict[key] = state
def discard(self, state): self._manage_incoming_state(state)
if self.contains_state(state): return True
dict.__delitem__(self, state.key)
self._manage_removed_state(state) def _add_unpresent(self, state, key):
# inlined form of add() called by loading.py
self._dict[key] = state
state._instance_dict = self._wr
def get(self, key, default=None): def get(self, key, default=None):
state = dict.get(self, key, default) if key not in self._dict:
if state is default:
return default return default
state = self._dict[key]
o = state.obj() o = state.obj()
if o is None:
o = state._is_really_none()
if o is None: if o is None:
return default return default
return o return o
# Py2K
def items(self):
return list(self.iteritems())
def iteritems(self): def items(self):
for state in dict.itervalues(self): values = self.all_states()
# end Py2K result = []
# Py3K for state in values:
#def items(self):
# for state in dict.values(self):
value = state.obj() value = state.obj()
if value is not None: if value is not None:
yield state.key, value result.append((state.key, value))
return result
# Py2K
def values(self): def values(self):
return list(self.itervalues()) values = self.all_states()
result = []
for state in values:
value = state.obj()
if value is not None:
result.append(value)
def itervalues(self): return result
for state in dict.itervalues(self):
# end Py2K def __iter__(self):
# Py3K return iter(self.keys())
#def values(self):
# for state in dict.values(self): if util.py2k:
instance = state.obj()
if instance is not None: def iteritems(self):
yield instance return iter(self.items())
def itervalues(self):
return iter(self.values())
def all_states(self): def all_states(self):
# Py3K if util.py2k:
# return list(dict.values(self)) return self._dict.values()
else:
# Py2K return list(self._dict.values())
return dict.values(self)
# end Py2K def _fast_discard(self, state):
self._dict.pop(state.key, None)
def discard(self, state):
st = self._dict.pop(state.key, None)
if st:
assert st is state
self._manage_removed_state(state)
def safe_discard(self, state):
if state.key in self._dict:
st = self._dict[state.key]
if st is state:
self._dict.pop(state.key, None)
self._manage_removed_state(state)
def prune(self): def prune(self):
return 0 return 0
class StrongInstanceDict(IdentityMap): class StrongInstanceDict(IdentityMap):
"""A 'strong-referencing' version of the identity map.
.. deprecated 1.1::
The strong
reference identity map is legacy. See the
recipe at :ref:`session_referencing_behavior` for
an event-based approach to maintaining strong identity
references.
"""
if util.py2k:
def itervalues(self):
return self._dict.itervalues()
def iteritems(self):
return self._dict.iteritems()
def __iter__(self):
return iter(self.dict_)
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def get(self, key, default=None):
return self._dict.get(key, default)
def values(self):
return self._dict.values()
def items(self):
return self._dict.items()
def all_states(self): def all_states(self):
return [attributes.instance_state(o) for o in self.itervalues()] return [attributes.instance_state(o) for o in self.values()]
def contains_state(self, state): def contains_state(self, state):
return state.key in self and attributes.instance_state(self[state.key]) is state return (
state.key in self and
attributes.instance_state(self[state.key]) is state)
def replace(self, state): def replace(self, state):
if dict.__contains__(self, state.key): if state.key in self._dict:
existing = dict.__getitem__(self, state.key) existing = self._dict[state.key]
existing = attributes.instance_state(existing) existing = attributes.instance_state(existing)
if existing is not state: if existing is not state:
self._manage_removed_state(existing) self._manage_removed_state(existing)
else: else:
return return
dict.__setitem__(self, state.key, state.obj()) self._dict[state.key] = state.obj()
self._manage_incoming_state(state) self._manage_incoming_state(state)
def add(self, state): def add(self, state):
if state.key in self: if state.key in self:
if attributes.instance_state(dict.__getitem__(self, state.key)) is not state: if attributes.instance_state(self._dict[state.key]) is not state:
raise AssertionError("A conflicting state is already present in the identity map for key %r" % (state.key, )) raise sa_exc.InvalidRequestError(
"Can't attach instance "
"%s; another instance with key %s is already "
"present in this session." % (
orm_util.state_str(state), state.key))
return False
else: else:
dict.__setitem__(self, state.key, state.obj()) self._dict[state.key] = state.obj()
self._manage_incoming_state(state) self._manage_incoming_state(state)
return True
def remove(self, state):
if attributes.instance_state(dict.pop(self, state.key)) is not state: def _add_unpresent(self, state, key):
raise AssertionError("State %s is not present in this identity map" % state) # inlined form of add() called by loading.py
self._manage_removed_state(state) self._dict[key] = state.obj()
state._instance_dict = self._wr
def _fast_discard(self, state):
self._dict.pop(state.key, None)
def discard(self, state): def discard(self, state):
if self.contains_state(state): obj = self._dict.pop(state.key, None)
dict.__delitem__(self, state.key) if obj is not None:
self._manage_removed_state(state) self._manage_removed_state(state)
st = attributes.instance_state(obj)
def remove_key(self, key): assert st is state
state = attributes.instance_state(dict.__getitem__(self, key))
self.remove(state) def safe_discard(self, state):
if state.key in self._dict:
obj = self._dict[state.key]
st = attributes.instance_state(obj)
if st is state:
self._dict.pop(state.key, None)
self._manage_removed_state(state)
def prune(self): def prune(self):
"""prune unreferenced, non-dirty states.""" """prune unreferenced, non-dirty states."""
ref_count = len(self) ref_count = len(self)
dirty = [s.obj() for s in self.all_states() if s.modified] dirty = [s.obj() for s in self.all_states() if s.modified]
@ -244,8 +315,7 @@ class StrongInstanceDict(IdentityMap):
keepers = weakref.WeakValueDictionary() keepers = weakref.WeakValueDictionary()
keepers.update(self) keepers.update(self)
dict.clear(self) self._dict.clear()
dict.update(self, keepers) self._dict.update(keepers)
self.modified = bool(dirty) self.modified = bool(dirty)
return ref_count - len(self) return ref_count - len(self)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,96 +1,120 @@
# scoping.py # orm/scoping.py
# Copyright (C) the SQLAlchemy authors and contributors # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
import sqlalchemy.exceptions as sa_exc from .. import exc as sa_exc
from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \ from ..util import ScopedRegistry, ThreadLocalRegistry, warn
to_list, get_cls_kwargs, deprecated from . import class_mapper, exc as orm_exc
from sqlalchemy.orm import ( from .session import Session
EXT_CONTINUE, MapperExtension, class_mapper, object_session
)
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm.session import Session
__all__ = ['ScopedSession'] __all__ = ['scoped_session']
class ScopedSession(object): class scoped_session(object):
"""Provides thread-local management of Sessions. """Provides scoped management of :class:`.Session` objects.
Usage:: See :ref:`unitofwork_contextual` for a tutorial.
Session = scoped_session(sessionmaker(autoflush=True))
... use session normally.
""" """
session_factory = None
"""The `session_factory` provided to `__init__` is stored in this
attribute and may be accessed at a later time. This can be useful when
a new non-scoped :class:`.Session` or :class:`.Connection` to the
database is needed."""
def __init__(self, session_factory, scopefunc=None): def __init__(self, session_factory, scopefunc=None):
"""Construct a new :class:`.scoped_session`.
:param session_factory: a factory to create new :class:`.Session`
instances. This is usually, but not necessarily, an instance
of :class:`.sessionmaker`.
:param scopefunc: optional function which defines
the current scope. If not passed, the :class:`.scoped_session`
object assumes "thread-local" scope, and will use
a Python ``threading.local()`` in order to maintain the current
:class:`.Session`. If passed, the function should return
a hashable token; this token will be used as the key in a
dictionary in order to store and retrieve the current
:class:`.Session`.
"""
self.session_factory = session_factory self.session_factory = session_factory
if scopefunc: if scopefunc:
self.registry = ScopedRegistry(session_factory, scopefunc) self.registry = ScopedRegistry(session_factory, scopefunc)
else: else:
self.registry = ThreadLocalRegistry(session_factory) self.registry = ThreadLocalRegistry(session_factory)
self.extension = _ScopedExt(self)
def __call__(self, **kwargs): def __call__(self, **kw):
if kwargs: r"""Return the current :class:`.Session`, creating it
scope = kwargs.pop('scope', False) using the :attr:`.scoped_session.session_factory` if not present.
:param \**kw: Keyword arguments will be passed to the
:attr:`.scoped_session.session_factory` callable, if an existing
:class:`.Session` is not present. If the :class:`.Session` is present
and keyword arguments have been passed,
:exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
"""
if kw:
scope = kw.pop('scope', False)
if scope is not None: if scope is not None:
if self.registry.has(): if self.registry.has():
raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.") raise sa_exc.InvalidRequestError(
"Scoped session is already present; "
"no new arguments may be specified.")
else: else:
sess = self.session_factory(**kwargs) sess = self.session_factory(**kw)
self.registry.set(sess) self.registry.set(sess)
return sess return sess
else: else:
return self.session_factory(**kwargs) return self.session_factory(**kw)
else: else:
return self.registry() return self.registry()
def remove(self): def remove(self):
"""Dispose of the current contextual session.""" """Dispose of the current :class:`.Session`, if present.
This will first call :meth:`.Session.close` method
on the current :class:`.Session`, which releases any existing
transactional/connection resources still being held; transactions
specifically are rolled back. The :class:`.Session` is then
discarded. Upon next usage within the same scope,
the :class:`.scoped_session` will produce a new
:class:`.Session` object.
"""
if self.registry.has(): if self.registry.has():
self.registry().close() self.registry().close()
self.registry.clear() self.registry.clear()
@deprecated("Session.mapper is deprecated. " def configure(self, **kwargs):
"Please see http://www.sqlalchemy.org/trac/wiki/UsageRecipes/SessionAwareMapper " """reconfigure the :class:`.sessionmaker` used by this
"for information on how to replicate its behavior.") :class:`.scoped_session`.
def mapper(self, *args, **kwargs):
"""return a mapper() function which associates this ScopedSession with the Mapper.
DEPRECATED. See :meth:`.sessionmaker.configure`.
""" """
from sqlalchemy.orm import mapper if self.registry.has():
warn('At least one scoped session is already present. '
extension_args = dict((arg, kwargs.pop(arg)) ' configure() can not affect sessions that have '
for arg in get_cls_kwargs(_ScopedExt) 'already been created.')
if arg in kwargs)
kwargs['extension'] = extension = to_list(kwargs.get('extension', []))
if extension_args:
extension.append(self.extension.configure(**extension_args))
else:
extension.append(self.extension)
return mapper(*args, **kwargs)
def configure(self, **kwargs):
"""reconfigure the sessionmaker used by this ScopedSession."""
self.session_factory.configure(**kwargs) self.session_factory.configure(**kwargs)
def query_property(self, query_cls=None): def query_property(self, query_cls=None):
"""return a class property which produces a `Query` object against the """return a class property which produces a :class:`.Query` object
class when called. against the class and the current :class:`.Session` when called.
e.g.:: e.g.::
Session = scoped_session(sessionmaker()) Session = scoped_session(sessionmaker())
class MyClass(object): class MyClass(object):
@ -124,82 +148,37 @@ class ScopedSession(object):
return None return None
return query() return query()
ScopedSession = scoped_session
"""Old name for backwards compatibility."""
def instrument(name): def instrument(name):
def do(self, *args, **kwargs): def do(self, *args, **kwargs):
return getattr(self.registry(), name)(*args, **kwargs) return getattr(self.registry(), name)(*args, **kwargs)
return do return do
for meth in Session.public_methods: for meth in Session.public_methods:
setattr(ScopedSession, meth, instrument(meth)) setattr(scoped_session, meth, instrument(meth))
def makeprop(name): def makeprop(name):
def set(self, attr): def set(self, attr):
setattr(self.registry(), name, attr) setattr(self.registry(), name, attr)
def get(self): def get(self):
return getattr(self.registry(), name) return getattr(self.registry(), name)
return property(get, set) return property(get, set)
for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'):
setattr(ScopedSession, prop, makeprop(prop)) for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map',
'is_active', 'autoflush', 'no_autoflush', 'info'):
setattr(scoped_session, prop, makeprop(prop))
def clslevel(name): def clslevel(name):
def do(cls, *args, **kwargs): def do(cls, *args, **kwargs):
return getattr(Session, name)(*args, **kwargs) return getattr(Session, name)(*args, **kwargs)
return classmethod(do) return classmethod(do)
for prop in ('close_all', 'object_session', 'identity_key'): for prop in ('close_all', 'object_session', 'identity_key'):
setattr(ScopedSession, prop, clslevel(prop)) setattr(scoped_session, prop, clslevel(prop))
class _ScopedExt(MapperExtension):
def __init__(self, context, validate=False, save_on_init=True):
self.context = context
self.validate = validate
self.save_on_init = save_on_init
self.set_kwargs_on_init = True
def validating(self):
return _ScopedExt(self.context, validate=True)
def configure(self, **kwargs):
return _ScopedExt(self.context, **kwargs)
def instrument_class(self, mapper, class_):
class query(object):
def __getattr__(s, key):
return getattr(self.context.registry().query(class_), key)
def __call__(s):
return self.context.registry().query(class_)
def __get__(self, instance, cls):
return self
if not 'query' in class_.__dict__:
class_.query = query()
if self.set_kwargs_on_init and class_.__init__ is object.__init__:
class_.__init__ = self._default__init__(mapper)
def _default__init__(ext, mapper):
def __init__(self, **kwargs):
for key, value in kwargs.iteritems():
if ext.validate:
if not mapper.get_property(key, resolve_synonyms=False,
raiseerr=False):
raise sa_exc.ArgumentError(
"Invalid __init__ argument: '%s'" % key)
setattr(self, key, value)
return __init__
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
if self.save_on_init:
session = kwargs.pop('_sa_session', None)
if session is None:
session = self.context.registry()
session._save_without_cascade(instance)
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
sess = object_session(instance)
if sess:
sess.expunge(instance)
return EXT_CONTINUE
def dispose_class(self, mapper, class_):
if hasattr(class_, 'query'):
delattr(class_, 'query')

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,98 +1,140 @@
# mapper/sync.py # orm/sync.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
"""private module containing functions used for copying data """private module containing functions used for copying data
between instances based on join conditions. between instances based on join conditions.
""" """
from sqlalchemy.orm import exc, util as mapperutil from . import exc, util as orm_util, attributes
def populate(source, source_mapper, dest, dest_mapper,
synchronize_pairs, uowcommit, flag_cascaded_pks):
source_dict = source.dict
dest_dict = dest.dict
def populate(source, source_mapper, dest, dest_mapper,
synchronize_pairs, uowcommit, passive_updates):
for l, r in synchronize_pairs: for l, r in synchronize_pairs:
try: try:
value = source_mapper._get_state_attr_by_column(source, l) # inline of source_mapper._get_state_attr_by_column
prop = source_mapper._columntoproperty[l]
value = source.manager[prop.key].impl.get(source, source_dict,
attributes.PASSIVE_OFF)
except exc.UnmappedColumnError: except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, dest_mapper, r) _raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
try: try:
dest_mapper._set_state_attr_by_column(dest, r, value) # inline of dest_mapper._set_state_attr_by_column
prop = dest_mapper._columntoproperty[r]
dest.manager[prop.key].impl.set(dest, dest_dict, value, None)
except exc.UnmappedColumnError: except exc.UnmappedColumnError:
_raise_col_to_prop(True, source_mapper, l, dest_mapper, r) _raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
# techically the "r.primary_key" check isn't # technically the "r.primary_key" check isn't
# needed here, but we check for this condition to limit # needed here, but we check for this condition to limit
# how often this logic is invoked for memory/performance # how often this logic is invoked for memory/performance
# reasons, since we only need this info for a primary key # reasons, since we only need this info for a primary key
# destination. # destination.
if l.primary_key and r.primary_key and \ if flag_cascaded_pks and l.primary_key and \
r.references(l) and passive_updates: r.primary_key and \
r.references(l):
uowcommit.attributes[("pk_cascaded", dest, r)] = True uowcommit.attributes[("pk_cascaded", dest, r)] = True
def bulk_populate_inherit_keys(
source_dict, source_mapper, synchronize_pairs):
# a simplified version of populate() used by bulk insert mode
for l, r in synchronize_pairs:
try:
prop = source_mapper._columntoproperty[l]
value = source_dict[prop.key]
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, source_mapper, r)
try:
prop = source_mapper._columntoproperty[r]
source_dict[prop.key] = value
except exc.UnmappedColumnError:
_raise_col_to_prop(True, source_mapper, l, source_mapper, r)
def clear(dest, dest_mapper, synchronize_pairs): def clear(dest, dest_mapper, synchronize_pairs):
for l, r in synchronize_pairs: for l, r in synchronize_pairs:
if r.primary_key: if r.primary_key and \
dest_mapper._get_state_attr_by_column(
dest, dest.dict, r) not in orm_util._none_set:
raise AssertionError( raise AssertionError(
"Dependency rule tried to blank-out primary key " "Dependency rule tried to blank-out primary key "
"column '%s' on instance '%s'" % "column '%s' on instance '%s'" %
(r, mapperutil.state_str(dest)) (r, orm_util.state_str(dest))
) )
try: try:
dest_mapper._set_state_attr_by_column(dest, r, None) dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None)
except exc.UnmappedColumnError: except exc.UnmappedColumnError:
_raise_col_to_prop(True, None, l, dest_mapper, r) _raise_col_to_prop(True, None, l, dest_mapper, r)
def update(source, source_mapper, dest, old_prefix, synchronize_pairs): def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
for l, r in synchronize_pairs: for l, r in synchronize_pairs:
try: try:
oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l) oldvalue = source_mapper._get_committed_attr_by_column(
value = source_mapper._get_state_attr_by_column(source, l) source.obj(), l)
value = source_mapper._get_state_attr_by_column(
source, source.dict, l, passive=attributes.PASSIVE_OFF)
except exc.UnmappedColumnError: except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r) _raise_col_to_prop(False, source_mapper, l, None, r)
dest[r.key] = value dest[r.key] = value
dest[old_prefix + r.key] = oldvalue dest[old_prefix + r.key] = oldvalue
def populate_dict(source, source_mapper, dict_, synchronize_pairs): def populate_dict(source, source_mapper, dict_, synchronize_pairs):
for l, r in synchronize_pairs: for l, r in synchronize_pairs:
try: try:
value = source_mapper._get_state_attr_by_column(source, l) value = source_mapper._get_state_attr_by_column(
source, source.dict, l, passive=attributes.PASSIVE_OFF)
except exc.UnmappedColumnError: except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r) _raise_col_to_prop(False, source_mapper, l, None, r)
dict_[r.key] = value dict_[r.key] = value
def source_modified(uowcommit, source, source_mapper, synchronize_pairs): def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
"""return true if the source object has changes from an old to a """return true if the source object has changes from an old to a
new value on the given synchronize pairs new value on the given synchronize pairs
""" """
for l, r in synchronize_pairs: for l, r in synchronize_pairs:
try: try:
prop = source_mapper._get_col_to_prop(l) prop = source_mapper._columntoproperty[l]
except exc.UnmappedColumnError: except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r) _raise_col_to_prop(False, source_mapper, l, None, r)
history = uowcommit.get_attribute_history(source, prop.key, passive=True) history = uowcommit.get_attribute_history(
if len(history.deleted): source, prop.key, attributes.PASSIVE_NO_INITIALIZE)
if bool(history.deleted):
return True return True
else: else:
return False return False
def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column):
def _raise_col_to_prop(isdest, source_mapper, source_column,
dest_mapper, dest_column):
if isdest: if isdest:
raise exc.UnmappedColumnError( raise exc.UnmappedColumnError(
"Can't execute sync rule for destination column '%s'; " "Can't execute sync rule for "
"mapper '%s' does not map this column. Try using an explicit" "destination column '%s'; mapper '%s' does not map "
" `foreign_keys` collection which does not include this column " "this column. Try using an explicit `foreign_keys` "
"(or use a viewonly=True relation)." % (dest_column, source_mapper) "collection which does not include this column (or use "
) "a viewonly=True relation)." % (dest_column, dest_mapper))
else: else:
raise exc.UnmappedColumnError( raise exc.UnmappedColumnError(
"Can't execute sync rule for source column '%s'; mapper '%s' " "Can't execute sync rule for "
"does not map this column. Try using an explicit `foreign_keys`" "source column '%s'; mapper '%s' does not map this "
" collection which does not include destination column '%s' (or " "column. Try using an explicit `foreign_keys` "
"use a viewonly=True relation)." % "collection which does not include destination column "
(source_column, source_mapper, dest_column) "'%s' (or use a viewonly=True relation)." %
) (source_column, source_mapper, dest_column))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,12 @@
# processors.py # sqlalchemy/processors.py
# Copyright (C) 2010-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com # Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
# #
# This module is part of SQLAlchemy and is released under # This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php # the MIT License: http://www.opensource.org/licenses/mit-license.php
"""defines generic type conversion functions, as used in bind and result """defines generic type conversion functions, as used in bind and result
processors. processors.
They all share one common characteristic: None is passed through unchanged. They all share one common characteristic: None is passed through unchanged.
@ -14,41 +16,47 @@ They all share one common characteristic: None is passed through unchanged.
import codecs import codecs
import re import re
import datetime import datetime
from . import util
def str_to_datetime_processor_factory(regexp, type_): def str_to_datetime_processor_factory(regexp, type_):
rmatch = regexp.match rmatch = regexp.match
# Even on python2.6 datetime.strptime is both slower than this code # Even on python2.6 datetime.strptime is both slower than this code
# and it does not support microseconds. # and it does not support microseconds.
has_named_groups = bool(regexp.groupindex)
def process(value): def process(value):
if value is None: if value is None:
return None return None
else: else:
return type_(*map(int, rmatch(value).groups(0))) try:
m = rmatch(value)
except TypeError:
raise ValueError("Couldn't parse %s string '%r' "
"- value is not a string." %
(type_.__name__, value))
if m is None:
raise ValueError("Couldn't parse %s string: "
"'%s'" % (type_.__name__, value))
if has_named_groups:
groups = m.groupdict(0)
return type_(**dict(list(zip(
iter(groups.keys()),
list(map(int, iter(groups.values())))
))))
else:
return type_(*list(map(int, m.groups(0))))
return process return process
try:
from sqlalchemy.cprocessors import UnicodeResultProcessor, \
DecimalResultProcessor, \
to_float, to_str, int_to_boolean, \
str_to_datetime, str_to_time, \
str_to_date
def to_unicode_processor_factory(encoding, errors=None): def boolean_to_int(value):
# this is cumbersome but it would be even more so on the C side if value is None:
if errors is not None: return None
return UnicodeResultProcessor(encoding, errors).process else:
else: return int(bool(value))
return UnicodeResultProcessor(encoding).process
def to_decimal_processor_factory(target_class, scale=10):
# Note that the scale argument is not taken into account for integer
# values in the C implementation while it is in the Python one.
# For example, the Python implementation might return
# Decimal('5.00000') whereas the C implementation will
# return Decimal('5'). These are equivalent of course.
return DecimalResultProcessor(target_class, "%%.%df" % scale).process
except ImportError:
def py_fallback():
def to_unicode_processor_factory(encoding, errors=None): def to_unicode_processor_factory(encoding, errors=None):
decoder = codecs.getdecoder(encoding) decoder = codecs.getdecoder(encoding)
@ -62,7 +70,22 @@ except ImportError:
return decoder(value, errors)[0] return decoder(value, errors)[0]
return process return process
def to_decimal_processor_factory(target_class, scale=10): def to_conditional_unicode_processor_factory(encoding, errors=None):
decoder = codecs.getdecoder(encoding)
def process(value):
if value is None:
return None
elif isinstance(value, util.text_type):
return value
else:
# decoder returns a tuple: (value, len). Simply dropping the
# len part is safe: it is done that way in the normal
# 'xx'.decode(encoding) code path.
return decoder(value, errors)[0]
return process
def to_decimal_processor_factory(target_class, scale):
fstring = "%%.%df" % scale fstring = "%%.%df" % scale
def process(value): def process(value):
@ -88,14 +111,45 @@ except ImportError:
if value is None: if value is None:
return None return None
else: else:
return value and True or False return bool(value)
DATETIME_RE = re.compile("(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?") DATETIME_RE = re.compile(
TIME_RE = re.compile("(\d+):(\d+):(\d+)(?:\.(\d+))?") r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?")
DATE_RE = re.compile("(\d+)-(\d+)-(\d+)") TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)")
str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE, str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE,
datetime.datetime) datetime.datetime)
str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time) str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time)
str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date) str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date)
return locals()
try:
from sqlalchemy.cprocessors import UnicodeResultProcessor, \
DecimalResultProcessor, \
to_float, to_str, int_to_boolean, \
str_to_datetime, str_to_time, \
str_to_date
def to_unicode_processor_factory(encoding, errors=None):
if errors is not None:
return UnicodeResultProcessor(encoding, errors).process
else:
return UnicodeResultProcessor(encoding).process
def to_conditional_unicode_processor_factory(encoding, errors=None):
if errors is not None:
return UnicodeResultProcessor(encoding, errors).conditional_process
else:
return UnicodeResultProcessor(encoding).conditional_process
def to_decimal_processor_factory(target_class, scale):
# Note that the scale argument is not taken into account for integer
# values in the C implementation while it is in the Python one.
# For example, the Python implementation might return
# Decimal('5.00000') whereas the C implementation will
# return Decimal('5'). These are equivalent of course.
return DecimalResultProcessor(target_class, "%%.%df" % scale).process
except ImportError:
globals().update(py_fallback())

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,11 @@
from sqlalchemy.sql.expression import ( # sql/__init__.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .expression import (
Alias, Alias,
ClauseElement, ClauseElement,
ColumnCollection, ColumnCollection,
@ -11,9 +18,12 @@ from sqlalchemy.sql.expression import (
Select, Select,
Selectable, Selectable,
TableClause, TableClause,
TableSample,
Update, Update,
alias, alias,
and_, and_,
any_,
all_,
asc, asc,
between, between,
bindparam, bindparam,
@ -28,12 +38,16 @@ from sqlalchemy.sql.expression import (
except_all, except_all,
exists, exists,
extract, extract,
false,
False_,
func, func,
funcfilter,
insert, insert,
intersect, intersect,
intersect_all, intersect_all,
join, join,
label, label,
lateral,
literal, literal,
literal_column, literal_column,
modifier, modifier,
@ -42,17 +56,43 @@ from sqlalchemy.sql.expression import (
or_, or_,
outerjoin, outerjoin,
outparam, outparam,
over,
select, select,
subquery, subquery,
table, table,
tablesample,
text, text,
true,
True_,
tuple_, tuple_,
type_coerce,
union, union,
union_all, union_all,
update, update,
) within_group
)
from sqlalchemy.sql.visitors import ClauseVisitor from .visitors import ClauseVisitor
__tmp = locals().keys()
__all__ = sorted([i for i in __tmp if not i.startswith('__')]) def __go(lcls):
global __all__
from .. import util as _sa_util
import inspect as _inspect
__all__ = sorted(name for name, obj in lcls.items()
if not (name.startswith('_') or _inspect.ismodule(obj)))
from .annotation import _prepare_annotations, Annotated
from .elements import AnnotatedColumnElement, ClauseList
from .selectable import AnnotatedFromClause
_prepare_annotations(ColumnElement, AnnotatedColumnElement)
_prepare_annotations(FromClause, AnnotatedFromClause)
_prepare_annotations(ClauseList, Annotated)
_sa_util.dependencies.resolve_all("sqlalchemy.sql")
from . import naming
__go(locals())

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,104 +1,813 @@
from sqlalchemy import types as sqltypes # sql/functions.py
from sqlalchemy.sql.expression import ( # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
ClauseList, Function, _literal_as_binds, text, _type_from_args # <see AUTHORS file>
) #
from sqlalchemy.sql import operators # This module is part of SQLAlchemy and is released under
from sqlalchemy.sql.visitors import VisitableType # the MIT License: http://www.opensource.org/licenses/mit-license.php
"""SQL function API, factories, and built-in functions.
"""
from . import sqltypes, schema
from .base import Executable, ColumnCollection
from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
literal_column, _type_from_args, ColumnElement, _clone,\
Over, BindParameter, FunctionFilter, Grouping, WithinGroup
from .selectable import FromClause, Select, Alias
from . import util as sqlutil
from . import operators
from .visitors import VisitableType
from .. import util
from . import annotation
_registry = util.defaultdict(dict)
def register_function(identifier, fn, package="_default"):
"""Associate a callable with a particular func. name.
This is normally called by _GenericMeta, but is also
available by itself so that a non-Function construct
can be associated with the :data:`.func` accessor (i.e.
CAST, EXTRACT).
"""
reg = _registry[package]
reg[identifier] = fn
class FunctionElement(Executable, ColumnElement, FromClause):
"""Base for SQL function-oriented constructs.
.. seealso::
:class:`.Function` - named SQL function.
:data:`.func` - namespace which produces registered or ad-hoc
:class:`.Function` instances.
:class:`.GenericFunction` - allows creation of registered function
types.
"""
packagenames = ()
def __init__(self, *clauses, **kwargs):
"""Construct a :class:`.FunctionElement`.
"""
args = [_literal_as_binds(c, self.name) for c in clauses]
self.clause_expr = ClauseList(
operator=operators.comma_op,
group_contents=True, *args).\
self_group()
def _execute_on_connection(self, connection, multiparams, params):
return connection._execute_function(self, multiparams, params)
@property
def columns(self):
"""The set of columns exported by this :class:`.FunctionElement`.
Function objects currently have no result column names built in;
this method returns a single-element column collection with
an anonymously named column.
An interim approach to providing named columns for a function
as a FROM clause is to build a :func:`.select` with the
desired columns::
from sqlalchemy.sql import column
stmt = select([column('x'), column('y')]).\
select_from(func.myfunction())
"""
return ColumnCollection(self.label(None))
@util.memoized_property
def clauses(self):
"""Return the underlying :class:`.ClauseList` which contains
the arguments for this :class:`.FunctionElement`.
"""
return self.clause_expr.element
def over(self, partition_by=None, order_by=None, rows=None, range_=None):
"""Produce an OVER clause against this function.
Used against aggregate or so-called "window" functions,
for database backends that support window functions.
The expression::
func.row_number().over(order_by='x')
is shorthand for::
from sqlalchemy import over
over(func.row_number(), order_by='x')
See :func:`~.expression.over` for a full description.
.. versionadded:: 0.7
"""
return Over(
self,
partition_by=partition_by,
order_by=order_by,
rows=rows,
range_=range_
)
def within_group(self, *order_by):
"""Produce a WITHIN GROUP (ORDER BY expr) clause against this function.
Used against so-called "ordered set aggregate" and "hypothetical
set aggregate" functions, including :class:`.percentile_cont`,
:class:`.rank`, :class:`.dense_rank`, etc.
See :func:`~.expression.within_group` for a full description.
.. versionadded:: 1.1
"""
return WithinGroup(self, *order_by)
def filter(self, *criterion):
"""Produce a FILTER clause against this function.
Used against aggregate and window functions,
for database backends that support the "FILTER" clause.
The expression::
func.count(1).filter(True)
is shorthand for::
from sqlalchemy import funcfilter
funcfilter(func.count(1), True)
.. versionadded:: 1.0.0
.. seealso::
:class:`.FunctionFilter`
:func:`.funcfilter`
"""
if not criterion:
return self
return FunctionFilter(self, *criterion)
@property
def _from_objects(self):
return self.clauses._from_objects
def get_children(self, **kwargs):
return self.clause_expr,
def _copy_internals(self, clone=_clone, **kw):
self.clause_expr = clone(self.clause_expr, **kw)
self._reset_exported()
FunctionElement.clauses._reset(self)
def within_group_type(self, within_group):
"""For types that define their return type as based on the criteria
within a WITHIN GROUP (ORDER BY) expression, called by the
:class:`.WithinGroup` construct.
Returns None by default, in which case the function's normal ``.type``
is used.
"""
return None
def alias(self, name=None, flat=False):
r"""Produce a :class:`.Alias` construct against this
:class:`.FunctionElement`.
This construct wraps the function in a named alias which
is suitable for the FROM clause, in the style accepted for example
by PostgreSQL.
e.g.::
from sqlalchemy.sql import column
stmt = select([column('data_view')]).\
select_from(SomeTable).\
select_from(func.unnest(SomeTable.data).alias('data_view')
)
Would produce:
.. sourcecode:: sql
SELECT data_view
FROM sometable, unnest(sometable.data) AS data_view
.. versionadded:: 0.9.8 The :meth:`.FunctionElement.alias` method
is now supported. Previously, this method's behavior was
undefined and did not behave consistently across versions.
"""
return Alias(self, name)
def select(self):
"""Produce a :func:`~.expression.select` construct
against this :class:`.FunctionElement`.
This is shorthand for::
s = select([function_element])
"""
s = Select([self])
if self._execution_options:
s = s.execution_options(**self._execution_options)
return s
def scalar(self):
"""Execute this :class:`.FunctionElement` against an embedded
'bind' and return a scalar value.
This first calls :meth:`~.FunctionElement.select` to
produce a SELECT construct.
Note that :class:`.FunctionElement` can be passed to
the :meth:`.Connectable.scalar` method of :class:`.Connection`
or :class:`.Engine`.
"""
return self.select().execute().scalar()
def execute(self):
"""Execute this :class:`.FunctionElement` against an embedded
'bind'.
This first calls :meth:`~.FunctionElement.select` to
produce a SELECT construct.
Note that :class:`.FunctionElement` can be passed to
the :meth:`.Connectable.execute` method of :class:`.Connection`
or :class:`.Engine`.
"""
return self.select().execute()
def _bind_param(self, operator, obj, type_=None):
return BindParameter(None, obj, _compared_to_operator=operator,
_compared_to_type=self.type, unique=True,
type_=type_)
def self_group(self, against=None):
# for the moment, we are parenthesizing all array-returning
# expressions against getitem. This may need to be made
# more portable if in the future we support other DBs
# besides postgresql.
if against is operators.getitem and \
isinstance(self.type, sqltypes.ARRAY):
return Grouping(self)
else:
return super(FunctionElement, self).self_group(against=against)
class _FunctionGenerator(object):
"""Generate :class:`.Function` objects based on getattr calls."""
def __init__(self, **opts):
self.__names = []
self.opts = opts
def __getattr__(self, name):
# passthru __ attributes; fixes pydoc
if name.startswith('__'):
try:
return self.__dict__[name]
except KeyError:
raise AttributeError(name)
elif name.endswith('_'):
name = name[0:-1]
f = _FunctionGenerator(**self.opts)
f.__names = list(self.__names) + [name]
return f
def __call__(self, *c, **kwargs):
o = self.opts.copy()
o.update(kwargs)
tokens = len(self.__names)
if tokens == 2:
package, fname = self.__names
elif tokens == 1:
package, fname = "_default", self.__names[0]
else:
package = None
if package is not None:
func = _registry[package].get(fname)
if func is not None:
return func(*c, **o)
return Function(self.__names[-1],
packagenames=self.__names[0:-1], *c, **o)
func = _FunctionGenerator()
"""Generate SQL function expressions.
:data:`.func` is a special object instance which generates SQL
functions based on name-based attributes, e.g.::
>>> print(func.count(1))
count(:param_1)
The element is a column-oriented SQL element like any other, and is
used in that way::
>>> print(select([func.count(table.c.id)]))
SELECT count(sometable.id) FROM sometable
Any name can be given to :data:`.func`. If the function name is unknown to
SQLAlchemy, it will be rendered exactly as is. For common SQL functions
which SQLAlchemy is aware of, the name may be interpreted as a *generic
function* which will be compiled appropriately to the target database::
>>> print(func.current_timestamp())
CURRENT_TIMESTAMP
To call functions which are present in dot-separated packages,
specify them in the same manner::
>>> print(func.stats.yield_curve(5, 10))
stats.yield_curve(:yield_curve_1, :yield_curve_2)
SQLAlchemy can be made aware of the return type of functions to enable
type-specific lexical and result-based behavior. For example, to ensure
that a string-based function returns a Unicode value and is similarly
treated as a string in expressions, specify
:class:`~sqlalchemy.types.Unicode` as the type:
>>> print(func.my_string(u'hi', type_=Unicode) + ' ' +
... func.my_string(u'there', type_=Unicode))
my_string(:my_string_1) || :my_string_2 || my_string(:my_string_3)
The object returned by a :data:`.func` call is usually an instance of
:class:`.Function`.
This object meets the "column" interface, including comparison and labeling
functions. The object can also be passed the :meth:`~.Connectable.execute`
method of a :class:`.Connection` or :class:`.Engine`, where it will be
wrapped inside of a SELECT statement first::
print(connection.execute(func.current_timestamp()).scalar())
In a few exception cases, the :data:`.func` accessor
will redirect a name to a built-in expression such as :func:`.cast`
or :func:`.extract`, as these names have well-known meaning
but are not exactly the same as "functions" from a SQLAlchemy
perspective.
.. versionadded:: 0.8 :data:`.func` can return non-function expression
constructs for common quasi-functional names like :func:`.cast`
and :func:`.extract`.
Functions which are interpreted as "generic" functions know how to
calculate their return type automatically. For a listing of known generic
functions, see :ref:`generic_functions`.
.. note::
The :data:`.func` construct has only limited support for calling
standalone "stored procedures", especially those with special
parameterization concerns.
See the section :ref:`stored_procedures` for details on how to use
the DBAPI-level ``callproc()`` method for fully traditional stored
procedures.
"""
modifier = _FunctionGenerator(group=False)
class Function(FunctionElement):
"""Describe a named SQL function.
See the superclass :class:`.FunctionElement` for a description
of public methods.
.. seealso::
:data:`.func` - namespace which produces registered or ad-hoc
:class:`.Function` instances.
:class:`.GenericFunction` - allows creation of registered function
types.
"""
__visit_name__ = 'function'
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
The :data:`.func` construct is normally used to construct
new :class:`.Function` instances.
"""
self.packagenames = kw.pop('packagenames', None) or []
self.name = name
self._bind = kw.get('bind', None)
self.type = sqltypes.to_instance(kw.get('type_', None))
FunctionElement.__init__(self, *clauses, **kw)
def _bind_param(self, operator, obj, type_=None):
return BindParameter(self.name, obj,
_compared_to_operator=operator,
_compared_to_type=self.type,
type_=type_,
unique=True)
class _GenericMeta(VisitableType): class _GenericMeta(VisitableType):
def __call__(self, *args, **kwargs): def __init__(cls, clsname, bases, clsdict):
args = [_literal_as_binds(c) for c in args] if annotation.Annotated not in cls.__mro__:
return type.__call__(self, *args, **kwargs) cls.name = name = clsdict.get('name', clsname)
cls.identifier = identifier = clsdict.get('identifier', name)
package = clsdict.pop('package', '_default')
# legacy
if '__return_type__' in clsdict:
cls.type = clsdict['__return_type__']
register_function(identifier, cls, package)
super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
class GenericFunction(Function):
__metaclass__ = _GenericMeta
def __init__(self, type_=None, args=(), **kwargs): class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
"""Define a 'generic' function.
A generic function is a pre-established :class:`.Function`
class that is instantiated automatically when called
by name from the :data:`.func` attribute. Note that
calling any name from :data:`.func` has the effect that
a new :class:`.Function` instance is created automatically,
given that name. The primary use case for defining
a :class:`.GenericFunction` class is so that a function
of a particular name may be given a fixed return type.
It can also include custom argument parsing schemes as well
as additional methods.
Subclasses of :class:`.GenericFunction` are automatically
registered under the name of the class. For
example, a user-defined function ``as_utc()`` would
be available immediately::
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy.types import DateTime
class as_utc(GenericFunction):
type = DateTime
print select([func.as_utc()])
User-defined generic functions can be organized into
packages by specifying the "package" attribute when defining
:class:`.GenericFunction`. Third party libraries
containing many functions may want to use this in order
to avoid name conflicts with other systems. For example,
if our ``as_utc()`` function were part of a package
"time"::
class as_utc(GenericFunction):
type = DateTime
package = "time"
The above function would be available from :data:`.func`
using the package name ``time``::
print select([func.time.as_utc()])
A final option is to allow the function to be accessed
from one name in :data:`.func` but to render as a different name.
The ``identifier`` attribute will override the name used to
access the function as loaded from :data:`.func`, but will retain
the usage of ``name`` as the rendered name::
class GeoBuffer(GenericFunction):
type = Geometry
package = "geo"
name = "ST_Buffer"
identifier = "buffer"
The above function will render as follows::
>>> print func.geo.buffer()
ST_Buffer()
.. versionadded:: 0.8 :class:`.GenericFunction` now supports
automatic registration of new functions as well as package
and custom naming support.
.. versionchanged:: 0.8 The attribute name ``type`` is used
to specify the function's return type at the class level.
Previously, the name ``__return_type__`` was used. This
name is still recognized for backwards-compatibility.
"""
coerce_arguments = True
def __init__(self, *args, **kwargs):
parsed_args = kwargs.pop('_parsed_args', None)
if parsed_args is None:
parsed_args = [_literal_as_binds(c, self.name) for c in args]
self.packagenames = [] self.packagenames = []
self.name = self.__class__.__name__
self._bind = kwargs.get('bind', None) self._bind = kwargs.get('bind', None)
self.clause_expr = ClauseList( self.clause_expr = ClauseList(
operator=operators.comma_op, operator=operators.comma_op,
group_contents=True, *args).self_group() group_contents=True, *parsed_args).self_group()
self.type = sqltypes.to_instance( self.type = sqltypes.to_instance(
type_ or getattr(self, '__return_type__', None)) kwargs.pop("type_", None) or getattr(self, 'type', None))
register_function("cast", Cast)
register_function("extract", Extract)
class next_value(GenericFunction):
"""Represent the 'next value', given a :class:`.Sequence`
as its single argument.
Compiles into the appropriate function on each backend,
or will raise NotImplementedError if used on a backend
that does not provide support for sequences.
"""
type = sqltypes.Integer()
name = "next_value"
def __init__(self, seq, **kw):
assert isinstance(seq, schema.Sequence), \
"next_value() accepts a Sequence object as input."
self._bind = kw.get('bind', None)
self.sequence = seq
@property
def _from_objects(self):
return []
class AnsiFunction(GenericFunction): class AnsiFunction(GenericFunction):
def __init__(self, **kwargs): def __init__(self, **kwargs):
GenericFunction.__init__(self, **kwargs) GenericFunction.__init__(self, **kwargs)
class ReturnTypeFromArgs(GenericFunction): class ReturnTypeFromArgs(GenericFunction):
"""Define a function whose return type is the same as its arguments.""" """Define a function whose return type is the same as its arguments."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c, self.name) for c in args]
kwargs.setdefault('type_', _type_from_args(args)) kwargs.setdefault('type_', _type_from_args(args))
GenericFunction.__init__(self, args=args, **kwargs) kwargs['_parsed_args'] = args
super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
class coalesce(ReturnTypeFromArgs): class coalesce(ReturnTypeFromArgs):
pass pass
class max(ReturnTypeFromArgs): class max(ReturnTypeFromArgs):
pass pass
class min(ReturnTypeFromArgs): class min(ReturnTypeFromArgs):
pass pass
class sum(ReturnTypeFromArgs): class sum(ReturnTypeFromArgs):
pass pass
class now(GenericFunction): class now(GenericFunction):
__return_type__ = sqltypes.DateTime type = sqltypes.DateTime
class concat(GenericFunction): class concat(GenericFunction):
__return_type__ = sqltypes.String type = sqltypes.String
def __init__(self, *args, **kwargs):
GenericFunction.__init__(self, args=args, **kwargs)
class char_length(GenericFunction): class char_length(GenericFunction):
__return_type__ = sqltypes.Integer type = sqltypes.Integer
def __init__(self, arg, **kwargs): def __init__(self, arg, **kwargs):
GenericFunction.__init__(self, args=[arg], **kwargs) GenericFunction.__init__(self, arg, **kwargs)
class random(GenericFunction): class random(GenericFunction):
def __init__(self, *args, **kwargs): pass
kwargs.setdefault('type_', None)
GenericFunction.__init__(self, args=args, **kwargs)
class count(GenericFunction): class count(GenericFunction):
"""The ANSI COUNT aggregate function. With no arguments, emits COUNT \*.""" r"""The ANSI COUNT aggregate function. With no arguments,
emits COUNT \*.
__return_type__ = sqltypes.Integer """
type = sqltypes.Integer
def __init__(self, expression=None, **kwargs): def __init__(self, expression=None, **kwargs):
if expression is None: if expression is None:
expression = text('*') expression = literal_column('*')
GenericFunction.__init__(self, args=(expression,), **kwargs) super(count, self).__init__(expression, **kwargs)
class current_date(AnsiFunction): class current_date(AnsiFunction):
__return_type__ = sqltypes.Date type = sqltypes.Date
class current_time(AnsiFunction): class current_time(AnsiFunction):
__return_type__ = sqltypes.Time type = sqltypes.Time
class current_timestamp(AnsiFunction): class current_timestamp(AnsiFunction):
__return_type__ = sqltypes.DateTime type = sqltypes.DateTime
class current_user(AnsiFunction): class current_user(AnsiFunction):
__return_type__ = sqltypes.String type = sqltypes.String
class localtime(AnsiFunction): class localtime(AnsiFunction):
__return_type__ = sqltypes.DateTime type = sqltypes.DateTime
class localtimestamp(AnsiFunction): class localtimestamp(AnsiFunction):
__return_type__ = sqltypes.DateTime type = sqltypes.DateTime
class session_user(AnsiFunction): class session_user(AnsiFunction):
__return_type__ = sqltypes.String type = sqltypes.String
class sysdate(AnsiFunction): class sysdate(AnsiFunction):
__return_type__ = sqltypes.DateTime type = sqltypes.DateTime
class user(AnsiFunction): class user(AnsiFunction):
__return_type__ = sqltypes.String type = sqltypes.String
class array_agg(GenericFunction):
"""support for the ARRAY_AGG function.
The ``func.array_agg(expr)`` construct returns an expression of
type :class:`.types.ARRAY`.
e.g.::
stmt = select([func.array_agg(table.c.values)[2:5]])
.. versionadded:: 1.1
.. seealso::
:func:`.postgresql.array_agg` - PostgreSQL-specific version that
returns :class:`.postgresql.ARRAY`, which has PG-specific operators added.
"""
type = sqltypes.ARRAY
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c) for c in args]
kwargs.setdefault('type_', self.type(_type_from_args(args)))
kwargs['_parsed_args'] = args
super(array_agg, self).__init__(*args, **kwargs)
class OrderedSetAgg(GenericFunction):
"""Define a function where the return type is based on the sort
expression type as defined by the expression passed to the
:meth:`.FunctionElement.within_group` method."""
array_for_multi_clause = False
def within_group_type(self, within_group):
func_clauses = self.clause_expr.element
order_by = sqlutil.unwrap_order_by(within_group.order_by)
if self.array_for_multi_clause and len(func_clauses.clauses) > 1:
return sqltypes.ARRAY(order_by[0].type)
else:
return order_by[0].type
class mode(OrderedSetAgg):
"""implement the ``mode`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is the same as the sort expression.
.. versionadded:: 1.1
"""
class percentile_cont(OrderedSetAgg):
"""implement the ``percentile_cont`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is the same as the sort expression,
or if the arguments are an array, an :class:`.types.ARRAY` of the sort
expression's type.
.. versionadded:: 1.1
"""
array_for_multi_clause = True
class percentile_disc(OrderedSetAgg):
"""implement the ``percentile_disc`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is the same as the sort expression,
or if the arguments are an array, an :class:`.types.ARRAY` of the sort
expression's type.
.. versionadded:: 1.1
"""
array_for_multi_clause = True
class rank(GenericFunction):
"""Implement the ``rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is :class:`.Integer`.
.. versionadded:: 1.1
"""
type = sqltypes.Integer()
class dense_rank(GenericFunction):
"""Implement the ``dense_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is :class:`.Integer`.
.. versionadded:: 1.1
"""
type = sqltypes.Integer()
class percent_rank(GenericFunction):
"""Implement the ``percent_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is :class:`.Numeric`.
.. versionadded:: 1.1
"""
type = sqltypes.Numeric()
class cume_dist(GenericFunction):
"""Implement the ``cume_dist`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
modifier to supply a sort expression to operate upon.
The return type of this function is :class:`.Numeric`.
.. versionadded:: 1.1
"""
type = sqltypes.Numeric()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,90 +1,137 @@
# sql/visitors.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Visitor/traversal interface and library functions. """Visitor/traversal interface and library functions.
SQLAlchemy schema and expression constructs rely on a Python-centric SQLAlchemy schema and expression constructs rely on a Python-centric
version of the classic "visitor" pattern as the primary way in which version of the classic "visitor" pattern as the primary way in which
they apply functionality. The most common use of this pattern they apply functionality. The most common use of this pattern
is statement compilation, where individual expression classes match is statement compilation, where individual expression classes match
up to rendering methods that produce a string result. Beyond this, up to rendering methods that produce a string result. Beyond this,
the visitor system is also used to inspect expressions for various the visitor system is also used to inspect expressions for various
information and patterns, as well as for usage in information and patterns, as well as for usage in
some kinds of expression transformation. Other kinds of transformation some kinds of expression transformation. Other kinds of transformation
use a non-visitor traversal system. use a non-visitor traversal system.
For many examples of how the visit system is used, see the For many examples of how the visit system is used, see the
sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules. sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules.
For an introduction to clause adaption, see For an introduction to clause adaption, see
http://techspot.zzzeek.org/?p=19 . http://techspot.zzzeek.org/2008/01/23/expression-transformations/
""" """
from collections import deque from collections import deque
import re from .. import util
from sqlalchemy import util
import operator import operator
from .. import exc
__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
'iterate_depthfirst', 'traverse_using', 'traverse',
'traverse_depthfirst',
'cloned_traverse', 'replacement_traverse']
__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
'iterate_depthfirst', 'traverse_using', 'traverse',
'cloned_traverse', 'replacement_traverse']
class VisitableType(type): class VisitableType(type):
"""Metaclass which checks for a `__visit_name__` attribute and """Metaclass which assigns a `_compiler_dispatch` method to classes
applies `_compiler_dispatch` method to classes. having a `__visit_name__` attribute.
The _compiler_dispatch attribute becomes an instance method which
looks approximately like the following::
def _compiler_dispatch (self, visitor, **kw):
'''Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.'''
visit_attr = 'visit_%s' % self.__visit_name__
return getattr(visitor, visit_attr)(self, **kw)
Classes having no __visit_name__ attribute will remain unaffected.
""" """
def __init__(cls, clsname, bases, clsdict): def __init__(cls, clsname, bases, clsdict):
if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'): if clsname != 'Visitable' and \
super(VisitableType, cls).__init__(clsname, bases, clsdict) hasattr(cls, '__visit_name__'):
return _generate_dispatch(cls)
# set up an optimized visit dispatch function
# for use by the compiler
visit_name = cls.__visit_name__
if isinstance(visit_name, str):
getter = operator.attrgetter("visit_%s" % visit_name)
def _compiler_dispatch(self, visitor, **kw):
return getter(visitor)(self, **kw)
else:
def _compiler_dispatch(self, visitor, **kw):
return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
cls._compiler_dispatch = _compiler_dispatch
super(VisitableType, cls).__init__(clsname, bases, clsdict) super(VisitableType, cls).__init__(clsname, bases, clsdict)
class Visitable(object):
def _generate_dispatch(cls):
"""Return an optimized visit dispatch function for the cls
for use by the compiler.
"""
if '__visit_name__' in cls.__dict__:
visit_name = cls.__visit_name__
if isinstance(visit_name, str):
# There is an optimization opportunity here because the
# the string name of the class's __visit_name__ is known at
# this early stage (import time) so it can be pre-constructed.
getter = operator.attrgetter("visit_%s" % visit_name)
def _compiler_dispatch(self, visitor, **kw):
try:
meth = getter(visitor)
except AttributeError:
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
else:
# The optimization opportunity is lost for this case because the
# __visit_name__ is not yet a string. As a result, the visit
# string has to be recalculated with each compilation.
def _compiler_dispatch(self, visitor, **kw):
visit_attr = 'visit_%s' % self.__visit_name__
try:
meth = getattr(visitor, visit_attr)
except AttributeError:
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
_compiler_dispatch.__doc__ = \
"""Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.
"""
cls._compiler_dispatch = _compiler_dispatch
class Visitable(util.with_metaclass(VisitableType, object)):
"""Base class for visitable objects, applies the """Base class for visitable objects, applies the
``VisitableType`` metaclass. ``VisitableType`` metaclass.
""" """
__metaclass__ = VisitableType
class ClauseVisitor(object): class ClauseVisitor(object):
"""Base class for visitor objects which can traverse using """Base class for visitor objects which can traverse using
the traverse() function. the traverse() function.
""" """
__traverse_options__ = {} __traverse_options__ = {}
def traverse_single(self, obj): def traverse_single(self, obj, **kw):
for v in self._visitor_iterator: for v in self._visitor_iterator:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None) meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth: if meth:
return meth(obj) return meth(obj, **kw)
def iterate(self, obj):
"""traverse the given expression structure, returning an iterator of all elements."""
def iterate(self, obj):
"""traverse the given expression structure, returning an iterator
of all elements.
"""
return iterate(obj, self.__traverse_options__) return iterate(obj, self.__traverse_options__)
def traverse(self, obj): def traverse(self, obj):
"""traverse and visit the given expression structure.""" """traverse and visit the given expression structure."""
return traverse(obj, self.__traverse_options__, self._visitor_dict) return traverse(obj, self.__traverse_options__, self._visitor_dict)
@util.memoized_property @util.memoized_property
def _visitor_dict(self): def _visitor_dict(self):
visitors = {} visitors = {}
@ -93,11 +140,11 @@ class ClauseVisitor(object):
if name.startswith('visit_'): if name.startswith('visit_'):
visitors[name[6:]] = getattr(self, name) visitors[name[6:]] = getattr(self, name)
return visitors return visitors
@property @property
def _visitor_iterator(self): def _visitor_iterator(self):
"""iterate through this visitor and each 'chained' visitor.""" """iterate through this visitor and each 'chained' visitor."""
v = self v = self
while v: while v:
yield v yield v
@ -105,41 +152,46 @@ class ClauseVisitor(object):
def chain(self, visitor): def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor. """'chain' an additional ClauseVisitor onto this ClauseVisitor.
the chained visitor will receive all visit events after this one. the chained visitor will receive all visit events after this one.
""" """
tail = list(self._visitor_iterator)[-1] tail = list(self._visitor_iterator)[-1]
tail._next = visitor tail._next = visitor
return self return self
class CloningVisitor(ClauseVisitor): class CloningVisitor(ClauseVisitor):
"""Base class for visitor objects which can traverse using """Base class for visitor objects which can traverse using
the cloned_traverse() function. the cloned_traverse() function.
""" """
def copy_and_process(self, list_): def copy_and_process(self, list_):
"""Apply cloned traversal to the given list of elements, and return the new list.""" """Apply cloned traversal to the given list of elements, and return
the new list.
"""
return [self.traverse(x) for x in list_] return [self.traverse(x) for x in list_]
def traverse(self, obj): def traverse(self, obj):
"""traverse and visit the given expression structure.""" """traverse and visit the given expression structure."""
return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict) return cloned_traverse(
obj, self.__traverse_options__, self._visitor_dict)
class ReplacingCloningVisitor(CloningVisitor): class ReplacingCloningVisitor(CloningVisitor):
"""Base class for visitor objects which can traverse using """Base class for visitor objects which can traverse using
the replacement_traverse() function. the replacement_traverse() function.
""" """
def replace(self, elem): def replace(self, elem):
"""receive pre-copied elements during a cloning traversal. """receive pre-copied elements during a cloning traversal.
If the method returns a new element, the element is used If the method returns a new element, the element is used
instead of creating a simple copy of the element. Traversal instead of creating a simple copy of the element. Traversal
will halt on the newly returned element if it is re-encountered. will halt on the newly returned element if it is re-encountered.
""" """
return None return None
@ -154,25 +206,39 @@ class ReplacingCloningVisitor(CloningVisitor):
return e return e
return replacement_traverse(obj, self.__traverse_options__, replace) return replacement_traverse(obj, self.__traverse_options__, replace)
def iterate(obj, opts): def iterate(obj, opts):
"""traverse the given expression structure, returning an iterator. """traverse the given expression structure, returning an iterator.
traversal is configured to be breadth-first. traversal is configured to be breadth-first.
""" """
# fasttrack for atomic elements like columns
children = obj.get_children(**opts)
if not children:
return [obj]
traversal = deque()
stack = deque([obj]) stack = deque([obj])
while stack: while stack:
t = stack.popleft() t = stack.popleft()
yield t traversal.append(t)
for c in t.get_children(**opts): for c in t.get_children(**opts):
stack.append(c) stack.append(c)
return iter(traversal)
def iterate_depthfirst(obj, opts): def iterate_depthfirst(obj, opts):
"""traverse the given expression structure, returning an iterator. """traverse the given expression structure, returning an iterator.
traversal is configured to be depth-first. traversal is configured to be depth-first.
""" """
# fasttrack for atomic elements like columns
children = obj.get_children(**opts)
if not children:
return [obj]
stack = deque([obj]) stack = deque([obj])
traversal = deque() traversal = deque()
while stack: while stack:
@ -182,75 +248,81 @@ def iterate_depthfirst(obj, opts):
stack.append(c) stack.append(c)
return iter(traversal) return iter(traversal)
def traverse_using(iterator, obj, visitors):
"""visit the given expression structure using the given iterator of objects."""
def traverse_using(iterator, obj, visitors):
"""visit the given expression structure using the given iterator of
objects.
"""
for target in iterator: for target in iterator:
meth = visitors.get(target.__visit_name__, None) meth = visitors.get(target.__visit_name__, None)
if meth: if meth:
meth(target) meth(target)
return obj return obj
def traverse(obj, opts, visitors):
"""traverse and visit the given expression structure using the default iterator."""
def traverse(obj, opts, visitors):
"""traverse and visit the given expression structure using the default
iterator.
"""
return traverse_using(iterate(obj, opts), obj, visitors) return traverse_using(iterate(obj, opts), obj, visitors)
def traverse_depthfirst(obj, opts, visitors):
"""traverse and visit the given expression structure using the depth-first iterator."""
def traverse_depthfirst(obj, opts, visitors):
"""traverse and visit the given expression structure using the
depth-first iterator.
"""
return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) return traverse_using(iterate_depthfirst(obj, opts), obj, visitors)
def cloned_traverse(obj, opts, visitors): def cloned_traverse(obj, opts, visitors):
"""clone the given expression structure, allowing modifications by visitors.""" """clone the given expression structure, allowing
modifications by visitors."""
cloned = util.column_dict()
def clone(element): cloned = {}
if element not in cloned: stop_on = set(opts.get('stop_on', []))
cloned[element] = element._clone()
return cloned[element]
obj = clone(obj) def clone(elem):
stack = [obj] if elem in stop_on:
return elem
else:
if id(elem) not in cloned:
cloned[id(elem)] = newelem = elem._clone()
newelem._copy_internals(clone=clone)
meth = visitors.get(newelem.__visit_name__, None)
if meth:
meth(newelem)
return cloned[id(elem)]
while stack: if obj is not None:
t = stack.pop() obj = clone(obj)
if t in cloned:
continue
t._copy_internals(clone=clone)
meth = visitors.get(t.__visit_name__, None)
if meth:
meth(t)
for c in t.get_children(**opts):
stack.append(c)
return obj return obj
def replacement_traverse(obj, opts, replace): def replacement_traverse(obj, opts, replace):
"""clone the given expression structure, allowing element replacement by a given replacement function.""" """clone the given expression structure, allowing element
replacement by a given replacement function."""
cloned = util.column_dict()
stop_on = util.column_set(opts.get('stop_on', []))
def clone(element): cloned = {}
newelem = replace(element) stop_on = set([id(x) for x in opts.get('stop_on', [])])
if newelem is not None:
stop_on.add(newelem)
return newelem
if element not in cloned: def clone(elem, **kw):
cloned[element] = element._clone() if id(elem) in stop_on or \
return cloned[element] 'no_replacement_traverse' in elem._annotations:
return elem
else:
newelem = replace(elem)
if newelem is not None:
stop_on.add(id(newelem))
return newelem
else:
if elem not in cloned:
cloned[elem] = newelem = elem._clone()
newelem._copy_internals(clone=clone, **kw)
return cloned[elem]
obj = clone(obj) if obj is not None:
stack = [obj] obj = clone(obj, **opts)
while stack:
t = stack.pop()
if t in stop_on:
continue
t._copy_internals(clone=clone)
for c in t.get_children(**opts):
stack.append(c)
return obj return obj

File diff suppressed because it is too large Load Diff