diff --git a/sqlalchemy/__init__.py b/sqlalchemy/__init__.py index 376b13e..a2116e0 100644 --- a/sqlalchemy/__init__.py +++ b/sqlalchemy/__init__.py @@ -1,24 +1,23 @@ -# __init__.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# sqlalchemy/__init__.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import inspect -import sys -import sqlalchemy.exc as exceptions -sys.modules['sqlalchemy.exceptions'] = exceptions - -from sqlalchemy.sql import ( +from .sql import ( alias, + all_, and_, + any_, asc, between, bindparam, case, cast, collate, + column, delete, desc, distinct, @@ -26,11 +25,14 @@ from sqlalchemy.sql import ( except_all, exists, extract, + false, func, + funcfilter, insert, intersect, intersect_all, join, + lateral, literal, literal_column, modifier, @@ -39,16 +41,25 @@ from sqlalchemy.sql import ( or_, outerjoin, outparam, + over, select, subquery, + table, + tablesample, text, + true, tuple_, + type_coerce, union, union_all, update, + within_group, ) -from sqlalchemy.types import ( +from .types import ( + ARRAY, + BIGINT, + BINARY, BLOB, BOOLEAN, BigInteger, @@ -68,12 +79,14 @@ from sqlalchemy.types import ( INTEGER, Integer, Interval, + JSON, LargeBinary, NCHAR, NVARCHAR, NUMERIC, Numeric, PickleType, + REAL, SMALLINT, SmallInteger, String, @@ -82,18 +95,19 @@ from sqlalchemy.types import ( TIMESTAMP, Text, Time, + TypeDecorator, Unicode, UnicodeText, + VARBINARY, VARCHAR, ) -from sqlalchemy.schema import ( +from .schema import ( CheckConstraint, Column, ColumnDefault, Constraint, - DDL, DefaultClause, FetchedValue, ForeignKey, @@ -106,14 +120,27 @@ from sqlalchemy.schema import ( Table, ThreadLocalMetaData, UniqueConstraint, - ) - -from sqlalchemy.engine import create_engine, engine_from_config + DDL, + BLANK_SCHEMA +) -__all__ = sorted(name for name, obj in locals().items() - if not (name.startswith('_') or inspect.ismodule(obj))) - -__version__ = '0.6beta3' +from .inspection import inspect +from .engine import create_engine, engine_from_config -del inspect, sys +__version__ = '1.1.9' + + +def __go(lcls): + global __all__ + + from . import events + from . 
import util as _sa_util + + import inspect as _inspect + + __all__ = sorted(name for name, obj in lcls.items() + if not (name.startswith('_') or _inspect.ismodule(obj))) + + _sa_util.dependencies.resolve_all("sqlalchemy") +__go(locals()) diff --git a/sqlalchemy/connectors/__init__.py b/sqlalchemy/connectors/__init__.py index f1383ad..5cf06d8 100644 --- a/sqlalchemy/connectors/__init__.py +++ b/sqlalchemy/connectors/__init__.py @@ -1,6 +1,10 @@ +# connectors/__init__.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php class Connector(object): pass - - \ No newline at end of file diff --git a/sqlalchemy/connectors/mxodbc.py b/sqlalchemy/connectors/mxodbc.py index 816474d..32e7e18 100644 --- a/sqlalchemy/connectors/mxodbc.py +++ b/sqlalchemy/connectors/mxodbc.py @@ -1,5 +1,12 @@ +# connectors/mxodbc.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + """ -Provide an SQLALchemy connector for the eGenix mxODBC commercial +Provide a SQLAlchemy connector for the eGenix mxODBC commercial Python adapter for ODBC. This is not a free product, but eGenix provides SQLAlchemy with a license for use in continuous integration testing. @@ -15,21 +22,19 @@ For more info on mxODBC, see http://www.egenix.com/ import sys import re import warnings -from decimal import Decimal -from sqlalchemy.connectors import Connector -from sqlalchemy import types as sqltypes -import sqlalchemy.processors as processors +from . import Connector + class MxODBCConnector(Connector): - driver='mxodbc' - + driver = 'mxodbc' + supports_sane_multi_rowcount = False - supports_unicode_statements = False - supports_unicode_binds = False - + supports_unicode_statements = True + supports_unicode_binds = True + supports_native_decimal = True - + @classmethod def dbapi(cls): # this classmethod will normally be replaced by an instance @@ -44,7 +49,7 @@ class MxODBCConnector(Connector): elif platform == 'darwin': from mx.ODBC import iODBC as module else: - raise ImportError, "Unrecognized platform for mxODBC import" + raise ImportError("Unrecognized platform for mxODBC import") return module @classmethod @@ -64,21 +69,21 @@ class MxODBCConnector(Connector): conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT conn.errorhandler = self._error_handler() return connect - + def _error_handler(self): """ Return a handler that adjusts mxODBC's raised Warnings to emit Python standard warnings. """ from mx.ODBC.Error import Warning as MxOdbcWarning - def error_handler(connection, cursor, errorclass, errorvalue): + def error_handler(connection, cursor, errorclass, errorvalue): if issubclass(errorclass, MxOdbcWarning): errorclass.__bases__ = (Warning,) warnings.warn(message=str(errorvalue), - category=errorclass, - stacklevel=2) + category=errorclass, + stacklevel=2) else: - raise errorclass, errorvalue + raise errorclass(errorvalue) return error_handler def create_connect_args(self, url): @@ -94,7 +99,7 @@ class MxODBCConnector(Connector): The arg 'errorhandler' is not used by SQLAlchemy and will not be populated.
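For example, a DSN-based URL such as ``mssql+mxodbc://scott:tiger@mydsn`` (the DSN name and credentials here are illustrative) is translated into roughly::

    (('mydsn',), {'user': 'scott', 'password': 'tiger'})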
- + """ opts = url.translate_connect_args(username='user') opts.update(url.query) @@ -103,9 +108,9 @@ class MxODBCConnector(Connector): opts.pop('database', None) return (args,), opts - def is_disconnect(self, e): - # eGenix recommends checking connection.closed here, - # but how can we get a handle on the current connection? + def is_disconnect(self, e, connection, cursor): + # TODO: eGenix recommends checking connection.closed here + # Does that detect dropped connections ? if isinstance(e, self.dbapi.ProgrammingError): return "connection already closed" in str(e) elif isinstance(e, self.dbapi.Error): @@ -114,10 +119,11 @@ class MxODBCConnector(Connector): return False def _get_server_version_info(self, connection): - # eGenix suggests using conn.dbms_version instead of what we're doing here + # eGenix suggests using conn.dbms_version instead + # of what we're doing here dbapi_con = connection.connection version = [] - r = re.compile('[.\-]') + r = re.compile(r'[.\-]') # 18 == pyodbc.SQL_DBMS_VER for n in r.split(dbapi_con.getinfo(18)[1]): try: @@ -126,21 +132,19 @@ class MxODBCConnector(Connector): version.append(n) return tuple(version) - def do_execute(self, cursor, statement, parameters, context=None): + def _get_direct(self, context): if context: native_odbc_execute = context.execution_options.\ - get('native_odbc_execute', 'auto') - if native_odbc_execute is True: - # user specified native_odbc_execute=True - cursor.execute(statement, parameters) - elif native_odbc_execute is False: - # user specified native_odbc_execute=False - cursor.executedirect(statement, parameters) - elif context.is_crud: - # statement is UPDATE, DELETE, INSERT - cursor.execute(statement, parameters) - else: - # all other statements - cursor.executedirect(statement, parameters) + get('native_odbc_execute', 'auto') + # default to direct=True in all cases, is more generally + # compatible especially with SQL Server + return False if native_odbc_execute is True else True else: - cursor.executedirect(statement, parameters) + return True + + def do_executemany(self, cursor, statement, parameters, context=None): + cursor.executemany( + statement, parameters, direct=self._get_direct(context)) + + def do_execute(self, cursor, statement, parameters, context=None): + cursor.execute(statement, parameters, direct=self._get_direct(context)) diff --git a/sqlalchemy/connectors/pyodbc.py b/sqlalchemy/connectors/pyodbc.py index b291f3e..ee8445d 100644 --- a/sqlalchemy/connectors/pyodbc.py +++ b/sqlalchemy/connectors/pyodbc.py @@ -1,29 +1,51 @@ -from sqlalchemy.connectors import Connector -from sqlalchemy.util import asbool +# connectors/pyodbc.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from . import Connector +from .. 
import util + import sys import re -import urllib -import decimal + class PyODBCConnector(Connector): - driver='pyodbc' + driver = 'pyodbc' supports_sane_multi_rowcount = False - # PyODBC unicode is broken on UCS-4 builds - supports_unicode = sys.maxunicode == 65535 - supports_unicode_statements = supports_unicode + + if util.py2k: + # PyODBC unicode is broken on UCS-4 builds + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = supports_unicode + supports_native_decimal = True default_paramstyle = 'named' - - # for non-DSN connections, this should + + # for non-DSN connections, this *may* be used to # hold the desired driver name pyodbc_driver_name = None - + # will be set to True after initialize() # if the freetds.so is detected freetds = False - + + # will be set to the string version of + # the FreeTDS driver if freetds is detected + freetds_driver_version = None + + # will be set to True after initialize() + # if the libessqlsrv.so is detected + easysoft = False + + def __init__(self, supports_unicode_binds=None, **kw): + super(PyODBCConnector, self).__init__(**kw) + self._user_supports_unicode_binds = supports_unicode_binds + @classmethod def dbapi(cls): return __import__('pyodbc') @@ -31,29 +53,53 @@ class PyODBCConnector(Connector): def create_connect_args(self, url): opts = url.translate_connect_args(username='user') opts.update(url.query) - + keys = opts + query = url.query connect_args = {} for param in ('ansi', 'unicode_results', 'autocommit'): if param in keys: - connect_args[param] = asbool(keys.pop(param)) + connect_args[param] = util.asbool(keys.pop(param)) if 'odbc_connect' in keys: - connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))] + connectors = [util.unquote_plus(keys.pop('odbc_connect'))] else: - dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys) + def check_quote(token): + if ";" in str(token): + token = "'%s'" % token + return token + + keys = dict( + (k, check_quote(v)) for k, v in keys.items() + ) + + dsn_connection = 'dsn' in keys or \ + ('host' in keys and 'database' not in keys) if dsn_connection: - connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))] + connectors = ['dsn=%s' % (keys.pop('host', '') or + keys.pop('dsn', ''))] else: port = '' - if 'port' in keys and not 'port' in query: + if 'port' in keys and 'port' not in query: port = ',%d' % int(keys.pop('port')) - connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name), - 'Server=%s%s' % (keys.pop('host', ''), port), - 'Database=%s' % keys.pop('database', '') ] + connectors = [] + driver = keys.pop('driver', self.pyodbc_driver_name) + if driver is None: + util.warn( + "No driver name specified; " + "this is expected by PyODBC when using " + "DSN-less connections") + else: + connectors.append("DRIVER={%s}" % driver) + + connectors.extend( + [ + 'Server=%s%s' % (keys.pop('host', ''), port), + 'Database=%s' % keys.pop('database', '') + ]) user = keys.pop("user", None) if user: @@ -62,20 +108,22 @@ class PyODBCConnector(Connector): else: connectors.append("Trusted_Connection=Yes") - # if set to 'Yes', the ODBC layer will try to automagically convert - # textual data from your database encoding to your client encoding - # This should obviously be set to 'No' if you query a cp1253 encoded - # database from a latin1 client... + # if set to 'Yes', the ODBC layer will try to automagically + # convert textual data from your database encoding to your + # client encoding. 
This should obviously be set to 'No' if + # you query a cp1253 encoded database from a latin1 client... if 'odbc_autotranslate' in keys: - connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) + connectors.append("AutoTranslate=%s" % + keys.pop("odbc_autotranslate")) - connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()]) - return [[";".join (connectors)], connect_args] - - def is_disconnect(self, e): + connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()]) + + return [[";".join(connectors)], connect_args] + + def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.ProgrammingError): return "The cursor's connection has been closed." in str(e) or \ - 'Attempt to use a closed connection.' in str(e) + 'Attempt to use a closed connection.' in str(e) elif isinstance(e, self.dbapi.Error): return '[08S01]' in str(e) else: @@ -84,27 +132,62 @@ class PyODBCConnector(Connector): def initialize(self, connection): # determine FreeTDS first. can't issue SQL easily # without getting unicode_statements/binds set up. - + pyodbc = self.dbapi dbapi_con = connection.connection - self.freetds = bool(re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME))) + _sql_driver_name = dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME) + self.freetds = bool(re.match(r".*libtdsodbc.*\.so", _sql_driver_name + )) + self.easysoft = bool(re.match(r".*libessqlsrv.*\.so", _sql_driver_name + )) + + if self.freetds: + self.freetds_driver_version = dbapi_con.getinfo( + pyodbc.SQL_DRIVER_VER) + + self.supports_unicode_statements = ( + not util.py2k or + (not self.freetds and not self.easysoft) + ) + + if self._user_supports_unicode_binds is not None: + self.supports_unicode_binds = self._user_supports_unicode_binds + elif util.py2k: + self.supports_unicode_binds = ( + not self.freetds or self.freetds_driver_version >= '0.91' + ) and not self.easysoft + else: + self.supports_unicode_binds = True - # the "Py2K only" part here is theoretical. - # have not tried pyodbc + python3.1 yet. - # Py2K - self.supports_unicode_statements = not self.freetds - self.supports_unicode_binds = not self.freetds - # end Py2K - # run other initialization which asks for user name, etc. super(PyODBCConnector, self).initialize(connection) + def _dbapi_version(self): + if not self.dbapi: + return () + return self._parse_dbapi_version(self.dbapi.version) + + def _parse_dbapi_version(self, vers): + m = re.match( + r'(?:py.*-)?([\d\.]+)(?:-(\w+))?', + vers + ) + if not m: + return () + vers = tuple([int(x) for x in m.group(1).split(".")]) + if m.group(2): + vers += (m.group(2),) + return vers + def _get_server_version_info(self, connection): + # NOTE: this function is not reliable, particularly when + # freetds is in use. Implement database-specific server version + # queries. dbapi_con = connection.connection version = [] - r = re.compile('[.\-]') + r = re.compile(r'[.\-]') for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): try: version.append(int(n)) diff --git a/sqlalchemy/connectors/zxJDBC.py b/sqlalchemy/connectors/zxJDBC.py index ae43128..8a5b749 100644 --- a/sqlalchemy/connectors/zxJDBC.py +++ b/sqlalchemy/connectors/zxJDBC.py @@ -1,20 +1,28 @@ +# connectors/zxJDBC.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + import sys -from sqlalchemy.connectors import Connector +from . 
import Connector + class ZxJDBCConnector(Connector): driver = 'zxjdbc' - + supports_sane_rowcount = False supports_sane_multi_rowcount = False - + supports_unicode_binds = True supports_unicode_statements = sys.version > '2.5.0+' description_encoding = None default_paramstyle = 'qmark' - + jdbc_db_name = None jdbc_driver_name = None - + @classmethod def dbapi(cls): from com.ziclix.python.sql import zxJDBC @@ -23,20 +31,24 @@ class ZxJDBCConnector(Connector): def _driver_kwargs(self): """Return kw arg dict to be sent to connect().""" return {} - + def _create_jdbc_url(self, url): """Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`""" return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host, - url.port is not None and ':%s' % url.port or '', + url.port is not None + and ':%s' % url.port or '', url.database) - + def create_connect_args(self, url): opts = self._driver_kwargs() opts.update(url.query) - return [[self._create_jdbc_url(url), url.username, url.password, self.jdbc_driver_name], - opts] + return [ + [self._create_jdbc_url(url), + url.username, url.password, + self.jdbc_driver_name], + opts] - def is_disconnect(self, e): + def is_disconnect(self, e, connection, cursor): if not isinstance(e, self.dbapi.ProgrammingError): return False e = str(e) diff --git a/sqlalchemy/dialects/__init__.py b/sqlalchemy/dialects/__init__.py index 91ca91f..44051f0 100644 --- a/sqlalchemy/dialects/__init__.py +++ b/sqlalchemy/dialects/__init__.py @@ -1,12 +1,56 @@ +# dialects/__init__.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + __all__ = ( -# 'access', -# 'firebird', -# 'informix', -# 'maxdb', -# 'mssql', + 'firebird', + 'mssql', 'mysql', 'oracle', 'postgresql', 'sqlite', -# 'sybase', - ) + 'sybase', +) + +from .. import util + +_translates = {'postgres': 'postgresql'} + +def _auto_fn(name): + """default dialect importer. + + plugs into the :class:`.PluginLoader` + as a first-hit system. + + """ + if "." in name: + dialect, driver = name.split(".") + else: + dialect = name + driver = "base" + + if dialect in _translates: + translated = _translates[dialect] + util.warn_deprecated( + "The '%s' dialect name has been " + "renamed to '%s'" % (dialect, translated) + ) + dialect = translated + try: + module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects + except ImportError: + return None + + module = getattr(module, dialect) + if hasattr(module, driver): + module = getattr(module, driver) + return lambda: module.dialect + else: + return None + +registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn) + +plugins = util.PluginLoader("sqlalchemy.plugins") \ No newline at end of file diff --git a/sqlalchemy/dialects/postgresql/__init__.py b/sqlalchemy/dialects/postgresql/__init__.py index 6aca1e1..a6872cf 100644 --- a/sqlalchemy/dialects/postgresql/__init__.py +++ b/sqlalchemy/dialects/postgresql/__init__.py @@ -1,14 +1,36 @@ -from sqlalchemy.dialects.postgresql import base, psycopg2, pg8000, pypostgresql, zxjdbc +# postgresql/__init__.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from . 
import base, psycopg2, pg8000, pypostgresql, pygresql, \ + zxjdbc, psycopg2cffi base.dialect = psycopg2.dialect -from sqlalchemy.dialects.postgresql.base import \ - INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, INET, \ - CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME,\ - DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect +from .base import \ + INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \ + INET, CIDR, UUID, BIT, MACADDR, OID, DOUBLE_PRECISION, TIMESTAMP, TIME, \ + DATE, BYTEA, BOOLEAN, INTERVAL, ENUM, dialect, TSVECTOR, DropEnumType, \ + CreateEnumType +from .hstore import HSTORE, hstore +from .json import JSON, JSONB +from .array import array, ARRAY, Any, All +from .ext import aggregate_order_by, ExcludeConstraint, array_agg +from .dml import insert, Insert + +from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \ + TSTZRANGE __all__ = ( -'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', 'FLOAT', 'REAL', 'INET', -'CIDR', 'UUID', 'BIT', 'MACADDR', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', -'DATE', 'BYTEA', 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect' + 'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', + 'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR', 'OID', + 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN', + 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array', 'HSTORE', + 'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE', + 'TSRANGE', 'TSTZRANGE', 'json', 'JSON', 'JSONB', 'Any', 'All', + 'DropEnumType', 'CreateEnumType', 'ExcludeConstraint', + 'aggregate_order_by', 'array_agg', 'insert', 'Insert' ) diff --git a/sqlalchemy/dialects/postgresql/base.py b/sqlalchemy/dialects/postgresql/base.py index bef2f1c..26d974e 100644 --- a/sqlalchemy/dialects/postgresql/base.py +++ b/sqlalchemy/dialects/postgresql/base.py @@ -1,123 +1,969 @@ -# postgresql.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# postgresql/base.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Support for the PostgreSQL database. +r""" +.. dialect:: postgresql + :name: PostgreSQL -For information on connecting using specific drivers, see the documentation section -regarding that driver. +.. _postgresql_sequences: Sequences/SERIAL ---------------- -PostgreSQL supports sequences, and SQLAlchemy uses these as the default means of creating -new primary key values for integer-based primary key columns. When creating tables, -SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns, -which generates a sequence corresponding to the column and associated with it based on -a naming convention. +PostgreSQL supports sequences, and SQLAlchemy uses these as the default means +of creating new primary key values for integer-based primary key columns. When +creating tables, SQLAlchemy will issue the ``SERIAL`` datatype for +integer-based primary key columns, which generates a sequence and server side +default corresponding to the column. 
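As a minimal sketch (the table and metadata names here are illustrative), a table such as::

    t = Table('t', metadata, Column('id', Integer, primary_key=True))

renders its DDL on PostgreSQL approximately as::

    CREATE TABLE t (
        id SERIAL NOT NULL,
        PRIMARY KEY (id)
    )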
-To specify a specific named sequence to be used for primary key generation, use the -:func:`~sqlalchemy.schema.Sequence` construct:: +To specify a specific named sequence to be used for primary key generation, +use the :func:`~sqlalchemy.schema.Sequence` construct:: - Table('sometable', metadata, + Table('sometable', metadata, Column('id', Integer, Sequence('some_id_seq'), primary_key=True) ) -Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of -having the "last insert identifier" available, the sequence is executed independently -beforehand and the new value is retrieved, to be used in the subsequent insert. Note -that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using -"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior -is used. +When SQLAlchemy issues a single INSERT statement, to fulfill the contract of +having the "last insert identifier" available, a RETURNING clause is added to +the INSERT statement which specifies the primary key columns should be +returned after the statement completes. The RETURNING functionality only takes +place if PostgreSQL 8.2 or later is in use. As a fallback approach, the +sequence, whether specified explicitly or implicitly via ``SERIAL``, is +executed independently beforehand, with the returned value used in the +subsequent insert. Note that when an +:func:`~sqlalchemy.sql.expression.insert()` construct is executed using +"executemany" semantics, the "last inserted identifier" functionality does not +apply; no RETURNING clause is emitted nor is the sequence pre-executed in this +case. -PostgreSQL 8.2 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports -as well. A future release of SQLA will use this feature by default in lieu of -sequence pre-execution in order to retrieve new primary key values, when available. +To disable the usage of RETURNING by default, specify the flag +``implicit_returning=False`` to :func:`.create_engine`. + +.. _postgresql_isolation_level: + +Transaction Isolation Level +--------------------------- + +All PostgreSQL dialects support setting of transaction isolation level +both via a dialect-specific parameter +:paramref:`.create_engine.isolation_level` accepted by :func:`.create_engine`, +as well as the :paramref:`.Connection.execution_options.isolation_level` +argument as passed to :meth:`.Connection.execution_options`. +When using a non-psycopg2 dialect, this feature works by issuing the command +``SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL <level>`` for +each new connection. For the special AUTOCOMMIT isolation level, +DBAPI-specific techniques are used. + +To set isolation level using :func:`.create_engine`:: + + engine = create_engine( + "postgresql+pg8000://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options( + isolation_level="READ COMMITTED" + ) + +Valid values for ``isolation_level`` include: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` - on psycopg2 / pg8000 only + +.. seealso:: + + :ref:`psycopg2_isolation_level` + + :ref:`pg8000_isolation_level` + +.. _postgresql_schema_reflection: + +Remote-Schema Table Introspection and PostgreSQL search_path +------------------------------------------------------------ + +The PostgreSQL dialect can reflect tables from any schema. 
The +:paramref:`.Table.schema` argument, or alternatively the +:paramref:`.MetaData.reflect.schema` argument determines which schema will +be searched for the table or tables. The reflected :class:`.Table` objects +will in all cases retain this ``.schema`` attribute as was specified. +However, with regards to tables which these :class:`.Table` objects refer to +via foreign key constraint, a decision must be made as to how the ``.schema`` +is represented in those remote tables, in the case where that remote +schema name is also a member of the current +`PostgreSQL search path +`_. + +By default, the PostgreSQL dialect mimics the behavior encouraged by +PostgreSQL's own ``pg_get_constraintdef()`` builtin procedure. This function +returns a sample definition for a particular foreign key constraint, +omitting the referenced schema name from that definition when the name is +also in the PostgreSQL schema search path. The interaction below +illustrates this behavior:: + + test=> CREATE TABLE test_schema.referred(id INTEGER PRIMARY KEY); + CREATE TABLE + test=> CREATE TABLE referring( + test(> id INTEGER PRIMARY KEY, + test(> referred_id INTEGER REFERENCES test_schema.referred(id)); + CREATE TABLE + test=> SET search_path TO public, test_schema; + test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM + test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n + test-> ON n.oid = c.relnamespace + test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid + test-> WHERE c.relname='referring' AND r.contype = 'f' + test-> ; + pg_get_constraintdef + --------------------------------------------------- + FOREIGN KEY (referred_id) REFERENCES referred(id) + (1 row) + +Above, we created a table ``referred`` as a member of the remote schema +``test_schema``, however when we added ``test_schema`` to the +PG ``search_path`` and then asked ``pg_get_constraintdef()`` for the +``FOREIGN KEY`` syntax, ``test_schema`` was not included in the output of +the function. + +On the other hand, if we set the search path back to the typical default +of ``public``:: + + test=> SET search_path TO public; + SET + +The same query against ``pg_get_constraintdef()`` now returns the fully +schema-qualified name for us:: + + test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM + test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n + test-> ON n.oid = c.relnamespace + test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid + test-> WHERE c.relname='referring' AND r.contype = 'f'; + pg_get_constraintdef + --------------------------------------------------------------- + FOREIGN KEY (referred_id) REFERENCES test_schema.referred(id) + (1 row) + +SQLAlchemy will by default use the return value of ``pg_get_constraintdef()`` +in order to determine the remote schema name. That is, if our ``search_path`` +were set to include ``test_schema``, and we invoked a table +reflection process as follows:: + + >>> from sqlalchemy import Table, MetaData, create_engine + >>> engine = create_engine("postgresql://scott:tiger@localhost/test") + >>> with engine.connect() as conn: + ... conn.execute("SET search_path TO test_schema, public") + ... meta = MetaData() + ... referring = Table('referring', meta, + ... autoload=True, autoload_with=conn) + ... 
+ + +The above process would deliver to the :attr:`.MetaData.tables` collection +``referred`` table named **without** the schema:: + + >>> meta.tables['referred'].schema is None + True + +To alter the behavior of reflection such that the referred schema is +maintained regardless of the ``search_path`` setting, use the +``postgresql_ignore_search_path`` option, which can be specified as a +dialect-specific argument to both :class:`.Table` as well as +:meth:`.MetaData.reflect`:: + + >>> with engine.connect() as conn: + ... conn.execute("SET search_path TO test_schema, public") + ... meta = MetaData() + ... referring = Table('referring', meta, autoload=True, + ... autoload_with=conn, + ... postgresql_ignore_search_path=True) + ... + + +We will now have ``test_schema.referred`` stored as schema-qualified:: + + >>> meta.tables['test_schema.referred'].schema + 'test_schema' + +.. sidebar:: Best Practices for PostgreSQL Schema reflection + + The description of PostgreSQL schema reflection behavior is complex, and + is the product of many years of dealing with widely varied use cases and + user preferences. But in fact, there's no need to understand any of it if + you just stick to the simplest use pattern: leave the ``search_path`` set + to its default of ``public`` only, never refer to the name ``public`` as + an explicit schema name otherwise, and refer to all other schema names + explicitly when building up a :class:`.Table` object. The options + described here are only for those users who can't, or prefer not to, stay + within these guidelines. + +Note that **in all cases**, the "default" schema is always reflected as +``None``. The "default" schema on PostgreSQL is that which is returned by the +PostgreSQL ``current_schema()`` function. On a typical PostgreSQL +installation, this is the name ``public``. So a table that refers to another +which is in the ``public`` (i.e. default) schema will always have the +``.schema`` attribute set to ``None``. + +.. versionadded:: 0.9.2 Added the ``postgresql_ignore_search_path`` + dialect-level option accepted by :class:`.Table` and + :meth:`.MetaData.reflect`. + + +.. seealso:: + + `The Schema Search Path + `_ + - on the PostgreSQL website. INSERT/UPDATE...RETURNING ------------------------- -The dialect supports PG 8.2's ``INSERT..RETURNING``, ``UPDATE..RETURNING`` and ``DELETE..RETURNING`` syntaxes, -but must be explicitly enabled on a per-statement basis:: +The dialect supports PG 8.2's ``INSERT..RETURNING``, ``UPDATE..RETURNING`` and +``DELETE..RETURNING`` syntaxes. ``INSERT..RETURNING`` is used by default +for single-row INSERT statements in order to fetch newly generated +primary key identifiers. To specify an explicit ``RETURNING`` clause, +use the :meth:`._UpdateBase.returning` method on a per-statement basis:: # INSERT..RETURNING - result = table.insert().returning(table.c.col1, table.c.col2).\\ + result = table.insert().returning(table.c.col1, table.c.col2).\ values(name='foo') print result.fetchall() - + # UPDATE..RETURNING - result = table.update().returning(table.c.col1, table.c.col2).\\ + result = table.update().returning(table.c.col1, table.c.col2).\ where(table.c.name=='foo').values(name='bar') print result.fetchall() # DELETE..RETURNING - result = table.delete().returning(table.c.col1, table.c.col2).\\ + result = table.delete().returning(table.c.col1, table.c.col2).\ where(table.c.name=='foo') print result.fetchall() -Indexes -------- +.. _postgresql_insert_on_conflict: -PostgreSQL supports partial indexes. 
To create them pass a postgresql_where -option to the Index constructor:: +INSERT...ON CONFLICT (Upsert) +------------------------------ + +Starting with version 9.5, PostgreSQL allows "upserts" (update or insert) +of rows into a table via the ``ON CONFLICT`` clause of the ``INSERT`` statement. +A candidate row will only be inserted if that row does not violate +any unique constraints. In the case of a unique constraint violation, +a secondary action can occur which can be either "DO UPDATE", indicating +that the data in the target row should be updated, or "DO NOTHING", +which indicates to silently skip this row. + +Conflicts are determined using existing unique constraints and indexes. These +constraints may be identified either using their name as stated in DDL, +or they may be *inferred* by stating the columns and conditions that comprise +the indexes. + +SQLAlchemy provides ``ON CONFLICT`` support via the PostgreSQL-specific +:func:`.postgresql.dml.insert()` function, which provides +the generative methods :meth:`~.postgresql.dml.Insert.on_conflict_do_update` +and :meth:`~.postgresql.dml.Insert.on_conflict_do_nothing`:: + + from sqlalchemy.dialects.postgresql import insert + + insert_stmt = insert(my_table).values( + id='some_existing_id', + data='inserted value') + + do_nothing_stmt = insert_stmt.on_conflict_do_nothing( + index_elements=['id'] + ) + + conn.execute(do_nothing_stmt) + + do_update_stmt = insert_stmt.on_conflict_do_update( + constraint='pk_my_table', + set_=dict(data='updated value') + ) + + conn.execute(do_update_stmt) + +Both methods supply the "target" of the conflict using either the +named constraint or column inference: + +* The :paramref:`.Insert.on_conflict_do_update.index_elements` argument + specifies a sequence containing string column names, :class:`.Column` objects, + and/or SQL expression elements, which would identify a unique index:: + + do_update_stmt = insert_stmt.on_conflict_do_update( + index_elements=['id'], + set_=dict(data='updated value') + ) + + do_update_stmt = insert_stmt.on_conflict_do_update( + index_elements=[my_table.c.id], + set_=dict(data='updated value') + ) + +* When using :paramref:`.Insert.on_conflict_do_update.index_elements` to + infer an index, a partial index can be inferred by also specifying + the :paramref:`.Insert.on_conflict_do_update.index_where` parameter:: + + from sqlalchemy.dialects.postgresql import insert + + stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') + stmt = stmt.on_conflict_do_update( + index_elements=[my_table.c.user_email], + index_where=my_table.c.user_email.like('%@gmail.com'), + set_=dict(data=stmt.excluded.data) + ) + conn.execute(stmt) + + +* The :paramref:`.Insert.on_conflict_do_update.constraint` argument is + used to specify an index directly rather than inferring it. This can be + the name of a UNIQUE constraint, a PRIMARY KEY constraint, or an INDEX:: + + do_update_stmt = insert_stmt.on_conflict_do_update( + constraint='my_table_idx_1', + set_=dict(data='updated value') + ) + + do_update_stmt = insert_stmt.on_conflict_do_update( + constraint='my_table_pk', + set_=dict(data='updated value') + ) + +* The :paramref:`.Insert.on_conflict_do_update.constraint` argument may + also refer to a SQLAlchemy construct representing a constraint, + e.g. :class:`.UniqueConstraint`, :class:`.PrimaryKeyConstraint`, + :class:`.Index`, or :class:`.ExcludeConstraint`. In this use, + if the constraint has a name, it is used directly. 
Otherwise, if the + constraint is unnamed, then inference will be used, where the expressions + and optional WHERE clause of the constraint will be spelled out in the + construct. This use is especially convenient + to refer to the named or unnamed primary key of a :class:`.Table` using the + :attr:`.Table.primary_key` attribute:: + + do_update_stmt = insert_stmt.on_conflict_do_update( + constraint=my_table.primary_key, + set_=dict(data='updated value') + ) + +``ON CONFLICT...DO UPDATE`` is used to perform an update of the already +existing row, using any combination of new values as well as values +from the proposed insertion. These values are specified using the +:paramref:`.Insert.on_conflict_do_update.set_` parameter. This +parameter accepts a dictionary which consists of direct values +for UPDATE:: + + from sqlalchemy.dialects.postgresql import insert + + stmt = insert(my_table).values(id='some_id', data='inserted value') + do_update_stmt = stmt.on_conflict_do_update( + index_elements=['id'], + set_=dict(data='updated value') + ) + conn.execute(do_update_stmt) + +.. warning:: + + The :meth:`.Insert.on_conflict_do_update` method does **not** take into + account Python-side default UPDATE values or generation functions, e.g. + those specified using :paramref:`.Column.onupdate`. + These values will not be exercised for an ON CONFLICT style of UPDATE, + unless they are manually specified in the + :paramref:`.Insert.on_conflict_do_update.set_` dictionary. + +In order to refer to the proposed insertion row, the special alias +:attr:`~.postgresql.dml.Insert.excluded` is available as an attribute on +the :class:`.postgresql.dml.Insert` object; this object is a +:class:`.ColumnCollection` alias that contains all columns of the target +table:: + + from sqlalchemy.dialects.postgresql import insert + + stmt = insert(my_table).values( + id='some_id', + data='inserted value', + author='jlh') + do_update_stmt = stmt.on_conflict_do_update( + index_elements=['id'], + set_=dict(data='updated value', author=stmt.excluded.author) + ) + conn.execute(do_update_stmt) + +The :meth:`.Insert.on_conflict_do_update` method also accepts +a WHERE clause using the :paramref:`.Insert.on_conflict_do_update.where` +parameter, which will limit those rows which receive an UPDATE:: + + from sqlalchemy.dialects.postgresql import insert + + stmt = insert(my_table).values( + id='some_id', + data='inserted value', + author='jlh') + on_update_stmt = stmt.on_conflict_do_update( + index_elements=['id'], + set_=dict(data='updated value', author=stmt.excluded.author), + where=(my_table.c.status == 2) + ) + conn.execute(on_update_stmt) + +``ON CONFLICT`` may also be used to skip inserting a row entirely +if any conflict with a unique or exclusion constraint occurs; below +this is illustrated using the +:meth:`~.postgresql.dml.Insert.on_conflict_do_nothing` method:: + + from sqlalchemy.dialects.postgresql import insert + + stmt = insert(my_table).values(id='some_id', data='inserted value') + stmt = stmt.on_conflict_do_nothing(index_elements=['id']) + conn.execute(stmt) + +If ``DO NOTHING`` is used without specifying any columns or constraint, +it has the effect of skipping the INSERT for any unique or exclusion +constraint violation which occurs:: + + from sqlalchemy.dialects.postgresql import insert + + stmt = insert(my_table).values(id='some_id', data='inserted value') + stmt = stmt.on_conflict_do_nothing() + conn.execute(stmt) + +.. versionadded:: 1.1 Added support for PostgreSQL ON CONFLICT clauses
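As an end-to-end sketch of the upsert pattern above (a minimal, runnable example; the connection URL, table and column names here are illustrative)::

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
    from sqlalchemy.dialects.postgresql import insert

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    metadata = MetaData()
    users = Table('users', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('name', String(50)))
    metadata.create_all(engine)

    stmt = insert(users).values(id=1, name='first write')
    upsert = stmt.on_conflict_do_update(
        index_elements=[users.c.id],
        set_=dict(name=stmt.excluded.name))

    with engine.connect() as conn:
        conn.execute(upsert)  # no conflict on the first run; a plain INSERT
        conn.execute(upsert)  # id=1 now exists, so the DO UPDATE path runs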
.. seealso:: + + `INSERT .. ON CONFLICT `_ - in the PostgreSQL documentation. + +.. _postgresql_match: + +Full Text Search +---------------- + +SQLAlchemy makes available the PostgreSQL ``@@`` operator via the +:meth:`.ColumnElement.match` method on any textual column expression. +On a PostgreSQL dialect, an expression like the following:: + + select([sometable.c.text.match("search string")]) + +will emit to the database:: + + SELECT text @@ to_tsquery('search string') FROM table + +The PostgreSQL text search functions such as ``to_tsquery()`` +and ``to_tsvector()`` are available +explicitly using the standard :data:`.func` construct. For example:: + + select([ + func.to_tsvector('fat cats ate rats').match('cat & rat') + ]) + +Emits the equivalent of:: + + SELECT to_tsvector('fat cats ate rats') @@ to_tsquery('cat & rat') + +The :class:`.postgresql.TSVECTOR` type can provide for explicit CAST:: + + from sqlalchemy.dialects.postgresql import TSVECTOR + from sqlalchemy import select, cast + select([cast("some text", TSVECTOR)]) + +produces a statement equivalent to:: + + SELECT CAST('some text' AS TSVECTOR) AS anon_1 + +Full Text Searches in PostgreSQL are influenced by a combination of: the +PostgreSQL setting of ``default_text_search_config``, the ``regconfig`` used +to build the GIN/GiST indexes, and the ``regconfig`` optionally passed in +during a query. + +When performing a Full Text Search against a column that has a GIN or +GiST index that is already pre-computed (which is common on full text +searches) one may need to explicitly pass in a particular PostgreSQL +``regconfig`` value to ensure the query-planner utilizes the index and does +not re-compute the column on demand. + +In order to provide for this explicit query planning, or to use different +search strategies, the ``match`` method accepts a ``postgresql_regconfig`` +keyword argument:: + + select([mytable.c.id]).where( + mytable.c.title.match('somestring', postgresql_regconfig='english') + ) + +Emits the equivalent of:: + + SELECT mytable.id FROM mytable + WHERE mytable.title @@ to_tsquery('english', 'somestring') + +One can also specifically pass in a `'regconfig'` value to the +``to_tsvector()`` command as the initial argument:: + + select([mytable.c.id]).where( + func.to_tsvector('english', mytable.c.title)\ + .match('somestring', postgresql_regconfig='english') + ) + +produces a statement equivalent to:: + + SELECT mytable.id FROM mytable + WHERE to_tsvector('english', mytable.title) @@ + to_tsquery('english', 'somestring') + +It is recommended that you use the ``EXPLAIN ANALYZE...`` tool from +PostgreSQL to ensure that you are generating queries with SQLAlchemy that +take full advantage of any indexes you may have created for full text search. + +FROM ONLY ... +------------------------ + +The dialect supports PostgreSQL's ONLY keyword for targeting only a particular +table in an inheritance hierarchy. This can be used to produce the +``SELECT ... FROM ONLY``, ``UPDATE ONLY ...``, and ``DELETE FROM ONLY ...`` +syntaxes. It uses SQLAlchemy's hints mechanism:: + + # SELECT ... FROM ONLY ... + result = table.select().with_hint(table, 'ONLY', 'postgresql') + print result.fetchall() + + # UPDATE ONLY ... + table.update(values=dict(foo='bar')).with_hint('ONLY', + dialect_name='postgresql') + + # DELETE FROM ONLY ... + table.delete().with_hint('ONLY', dialect_name='postgresql')
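As a sketch of what the first hint above renders, assuming a hypothetical table ``accounts`` with a single ``id`` column, the SELECT comes out approximately as::

    SELECT accounts.id
    FROM ONLY accounts

..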
_postgresql_indexes: + +PostgreSQL-Specific Index Options +--------------------------------- + +Several extensions to the :class:`.Index` construct are available, specific +to the PostgreSQL dialect. + +.. _postgresql_partial_indexes: + +Partial Indexes +^^^^^^^^^^^^^^^^ + +Partial indexes add criterion to the index definition so that the index is +applied to a subset of rows. These can be specified on :class:`.Index` +using the ``postgresql_where`` keyword argument:: + + Index('my_index', my_table.c.id, postgresql_where=my_table.c.value > 10) + +Operator Classes +^^^^^^^^^^^^^^^^^ + +PostgreSQL allows the specification of an *operator class* for each column of +an index (see +http://www.postgresql.org/docs/8.3/interactive/indexes-opclass.html). +The :class:`.Index` construct allows these to be specified via the +``postgresql_ops`` keyword argument:: + + Index('my_index', my_table.c.id, my_table.c.data, + postgresql_ops={ + 'data': 'text_pattern_ops', + 'id': 'int4_ops' + }) + +.. versionadded:: 0.7.2 + ``postgresql_ops`` keyword argument to :class:`.Index` construct. + +Note that the keys in the ``postgresql_ops`` dictionary are the "key" name of +the :class:`.Column`, i.e. the name used to access it from the ``.c`` +collection of :class:`.Table`, which can be configured to be different than +the actual name of the column as expressed in the database. + +Index Types +^^^^^^^^^^^^ + +PostgreSQL provides several index types: B-Tree, Hash, GiST, and GIN, as well +as the ability for users to create their own (see +http://www.postgresql.org/docs/8.3/static/indexes-types.html). These can be +specified on :class:`.Index` using the ``postgresql_using`` keyword argument:: + + Index('my_index', my_table.c.data, postgresql_using='gin') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX command, so it *must* be a valid index type for your +version of PostgreSQL. + +.. _postgresql_index_storage: + +Index Storage Parameters +^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL allows storage parameters to be set on indexes. The storage +parameters available depend on the index method used by the index. Storage +parameters can be specified on :class:`.Index` using the ``postgresql_with`` +keyword argument:: + + Index('my_index', my_table.c.data, postgresql_with={"fillfactor": 50}) + +.. versionadded:: 1.0.6 + +PostgreSQL allows to define the tablespace in which to create the index. +The tablespace can be specified on :class:`.Index` using the +``postgresql_tablespace`` keyword argument:: + + Index('my_index', my_table.c.data, postgresql_tablespace='my_tablespace') + +.. versionadded:: 1.1 + +Note that the same option is available on :class:`.Table` as well. + +.. _postgresql_index_concurrently: + +Indexes with CONCURRENTLY +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The PostgreSQL index option CONCURRENTLY is supported by passing the +flag ``postgresql_concurrently`` to the :class:`.Index` construct:: + + tbl = Table('testtbl', m, Column('data', Integer)) + + idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True) + +The above index construct will render DDL for CREATE INDEX, assuming +PostgreSQL 8.2 or higher is detected or for a connection-less dialect, as:: + + CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data) + +For DROP INDEX, assuming PostgreSQL 9.2 or higher is detected or for +a connection-less dialect, it will emit:: + + DROP INDEX CONCURRENTLY test_idx1 + +.. versionadded:: 1.1 support for CONCURRENTLY on DROP INDEX. 
The + CONCURRENTLY keyword is now only emitted if a high enough version + of PostgreSQL is detected on the connection (or for a connection-less + dialect). + +When using CONCURRENTLY, the Postgresql database requires that the statement +be invoked outside of a transaction block. The Python DBAPI enforces that +even for a single statement, a transaction is present, so to use this +construct, the DBAPI's "autocommit" mode must be used:: + + metadata = MetaData() + table = Table( + "foo", metadata, + Column("id", String)) + index = Index( + "foo_idx", table.c.id, postgresql_concurrently=True) + + with engine.connect() as conn: + with conn.execution_options(isolation_level='AUTOCOMMIT'): + table.create(conn) + +.. seealso:: + + :ref:`postgresql_isolation_level` + +.. _postgresql_index_reflection: + +PostgreSQL Index Reflection +--------------------------- + +The PostgreSQL database creates a UNIQUE INDEX implicitly whenever the +UNIQUE CONSTRAINT construct is used. When inspecting a table using +:class:`.Inspector`, the :meth:`.Inspector.get_indexes` +and the :meth:`.Inspector.get_unique_constraints` will report on these +two constructs distinctly; in the case of the index, the key +``duplicates_constraint`` will be present in the index entry if it is +detected as mirroring a constraint. When performing reflection using +``Table(..., autoload=True)``, the UNIQUE INDEX is **not** returned +in :attr:`.Table.indexes` when it is detected as mirroring a +:class:`.UniqueConstraint` in the :attr:`.Table.constraints` collection. + +.. versionchanged:: 1.0.0 - :class:`.Table` reflection now includes + :class:`.UniqueConstraint` objects present in the :attr:`.Table.constraints` + collection; the PostgreSQL backend will no longer include a "mirrored" + :class:`.Index` construct in :attr:`.Table.indexes` if it is detected + as corresponding to a unique constraint. + +Special Reflection Options +-------------------------- + +The :class:`.Inspector` used for the PostgreSQL backend is an instance +of :class:`.PGInspector`, which offers additional methods:: + + from sqlalchemy import create_engine, inspect + + engine = create_engine("postgresql+psycopg2://localhost/test") + insp = inspect(engine) # will be a PGInspector + + print(insp.get_enums()) + +.. autoclass:: PGInspector + :members: + +.. _postgresql_table_options: + +PostgreSQL Table Options +------------------------- + +Several options for CREATE TABLE are supported directly by the PostgreSQL +dialect in conjunction with the :class:`.Table` construct: + +* ``TABLESPACE``:: + + Table("some_table", metadata, ..., postgresql_tablespace='some_tablespace') + + The above option is also available on the :class:`.Index` construct. + +* ``ON COMMIT``:: + + Table("some_table", metadata, ..., postgresql_on_commit='PRESERVE ROWS') + +* ``WITH OIDS``:: + + Table("some_table", metadata, ..., postgresql_with_oids=True) + +* ``WITHOUT OIDS``:: + + Table("some_table", metadata, ..., postgresql_with_oids=False) + +* ``INHERITS``:: + + Table("some_table", metadata, ..., postgresql_inherits="some_supertable") + + Table("some_table", metadata, ..., postgresql_inherits=("t1", "t2", ...)) + +.. versionadded:: 1.0.0 + +.. 
seealso:: + + `PostgreSQL CREATE TABLE options + `_ + +ARRAY Types +----------- + +The PostgreSQL dialect supports arrays, both as multidimensional column types +as well as array literals: + +* :class:`.postgresql.ARRAY` - ARRAY datatype + +* :class:`.postgresql.array` - array literal + +* :func:`.postgresql.array_agg` - ARRAY_AGG SQL function + +* :class:`.postgresql.aggregate_order_by` - helper for PG's ORDER BY aggregate + function syntax. + +JSON Types +---------- + +The PostgreSQL dialect supports both JSON and JSONB datatypes, including +psycopg2's native support and support for all of PostgreSQL's special +operators: + +* :class:`.postgresql.JSON` + +* :class:`.postgresql.JSONB` + +HSTORE Type +----------- + +The PostgreSQL HSTORE type as well as hstore literals are supported: + +* :class:`.postgresql.HSTORE` - HSTORE datatype + +* :class:`.postgresql.hstore` - hstore literal + +ENUM Types +---------- + +PostgreSQL has an independently creatable TYPE structure which is used +to implement an enumerated type. This approach introduces significant +complexity on the SQLAlchemy side in terms of when this type should be +CREATED and DROPPED. The type object is also an independently reflectable +entity. The following sections should be consulted: + +* :class:`.postgresql.ENUM` - DDL and typing support for ENUM. + +* :meth:`.PGInspector.get_enums` - retrieve a listing of current ENUM types + +* :meth:`.postgresql.ENUM.create`, :meth:`.postgresql.ENUM.drop` - individual + CREATE and DROP commands for ENUM. + +.. _postgresql_array_of_enum: + +Using ENUM with ARRAY +^^^^^^^^^^^^^^^^^^^^^ + +The combination of ENUM and ARRAY is not directly supported by backend +DBAPIs at this time. In order to send and receive an ARRAY of ENUM, +use the following workaround type:: + + class ArrayOfEnum(ARRAY): + + def bind_expression(self, bindvalue): + return sa.cast(bindvalue, self) + + def result_processor(self, dialect, coltype): + super_rp = super(ArrayOfEnum, self).result_processor( + dialect, coltype) + + def handle_raw_string(value): + inner = re.match(r"^{(.*)}$", value).group(1) + return inner.split(",") if inner else [] + + def process(value): + if value is None: + return None + return super_rp(handle_raw_string(value)) + return process + +E.g.:: + + Table( + 'mydata', metadata, + Column('id', Integer, primary_key=True), + Column('data', ArrayOfEnum(ENUM('a', 'b', 'c', name='myenum'))) + ) + +This type is not included as a built-in type as it would be incompatible +with a DBAPI that suddenly decides to support ARRAY of ENUM directly in +a new version.
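A brief usage sketch of the workaround type, assuming the ``mydata`` table above has been created and an ``engine`` is available (the values shown are illustrative)::

    from sqlalchemy import select

    with engine.connect() as conn:
        conn.execute(mydata.insert(), {"id": 1, "data": ["a", "b"]})
        # the result_processor above converts the raw "{a,b}" string
        # back into a Python list
        print(conn.execute(select([mydata.c.data])).scalar())

..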
_postgresql_array_of_json: + +Using JSON/JSONB with ARRAY +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Similar to using ENUM, for an ARRAY of JSON/JSONB we need to render the +appropriate CAST, however current psycopg2 drivers seem to handle the result +for ARRAY of JSON automatically, so the type is simpler:: + + + class CastingArray(ARRAY): + def bind_expression(self, bindvalue): + return sa.cast(bindvalue, self) + +E.g.:: + + Table( + 'mydata', metadata, + Column('id', Integer, primary_key=True), + Column('data', CastingArray(JSONB)) + ) - Index('my_index', my_table.c.id, postgresql_where=tbl.c.value > 10) """ - +from collections import defaultdict import re +import datetime as dt -from sqlalchemy import schema as sa_schema -from sqlalchemy import sql, schema, exc, util -from sqlalchemy.engine import base, default, reflection -from sqlalchemy.sql import compiler, expression, util as sql_util -from sqlalchemy.sql import operators as sql_operators -from sqlalchemy import types as sqltypes + +from sqlalchemy.sql import elements +from ... import sql, schema, exc, util +from ...engine import default, reflection +from ...sql import compiler, expression +from ... import types as sqltypes + +try: + from uuid import UUID as _python_UUID +except ImportError: + _python_UUID = None from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \ - CHAR, TEXT, FLOAT, NUMERIC, \ - DATE, BOOLEAN + CHAR, TEXT, FLOAT, NUMERIC, \ + DATE, BOOLEAN, REAL -class REAL(sqltypes.Float): - __visit_name__ = "REAL" +AUTOCOMMIT_REGEXP = re.compile( + r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|' + 'IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW)', + re.I | re.UNICODE) + +RESERVED_WORDS = set( + ["all", "analyse", "analyze", "and", "any", "array", "as", "asc", + "asymmetric", "both", "case", "cast", "check", "collate", "column", + "constraint", "create", "current_catalog", "current_date", + "current_role", "current_time", "current_timestamp", "current_user", + "default", "deferrable", "desc", "distinct", "do", "else", "end", + "except", "false", "fetch", "for", "foreign", "from", "grant", "group", + "having", "in", "initially", "intersect", "into", "leading", "limit", + "localtime", "localtimestamp", "new", "not", "null", "of", "off", + "offset", "old", "on", "only", "or", "order", "placing", "primary", + "references", "returning", "select", "session_user", "some", "symmetric", + "table", "then", "to", "trailing", "true", "union", "unique", "user", + "using", "variadic", "when", "where", "window", "with", "authorization", + "between", "binary", "cross", "current_schema", "freeze", "full", + "ilike", "inner", "is", "isnull", "join", "left", "like", "natural", + "notnull", "outer", "over", "overlaps", "right", "similar", "verbose" + ]) + +_DECIMAL_TYPES = (1231, 1700) +_FLOAT_TYPES = (700, 701, 1021, 1022) +_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) class BYTEA(sqltypes.LargeBinary): __visit_name__ = 'BYTEA' + class DOUBLE_PRECISION(sqltypes.Float): __visit_name__ = 'DOUBLE_PRECISION' - + + class INET(sqltypes.TypeEngine): __visit_name__ = "INET" PGInet = INET + class CIDR(sqltypes.TypeEngine): __visit_name__ = "CIDR" PGCidr = CIDR + class MACADDR(sqltypes.TypeEngine): __visit_name__ = "MACADDR" PGMacAddr = MACADDR + +class OID(sqltypes.TypeEngine): + + """Provide the PostgreSQL OID type. + + .. 
versionadded:: 0.9.5 + + """ + __visit_name__ = "OID" + + class TIMESTAMP(sqltypes.TIMESTAMP): + def __init__(self, timezone=False, precision=None): super(TIMESTAMP, self).__init__(timezone=timezone) self.precision = precision - + + class TIME(sqltypes.TIME): + def __init__(self, timezone=False, precision=None): super(TIME, self).__init__(timezone=timezone) self.precision = precision - + + class INTERVAL(sqltypes.TypeEngine): + + """PostgreSQL INTERVAL type. + + The INTERVAL type may not be supported on all DBAPIs. + It is known to work on psycopg2 and not pg8000 or zxjdbc. + + """ __visit_name__ = 'INTERVAL' + def __init__(self, precision=None): self.precision = precision - - def adapt(self, impltype): - return impltype(self.precision) @classmethod def _adapt_from_generic_interval(cls, interval): @@ -126,272 +972,627 @@ class INTERVAL(sqltypes.TypeEngine): @property def _type_affinity(self): return sqltypes.Interval - + + @property + def python_type(self): + return dt.timedelta + PGInterval = INTERVAL + class BIT(sqltypes.TypeEngine): __visit_name__ = 'BIT' + + def __init__(self, length=None, varying=False): + if not varying: + # BIT without VARYING defaults to length 1 + self.length = length or 1 + else: + # but BIT VARYING can be unlimited-length, so no default + self.length = length + self.varying = varying + PGBit = BIT + class UUID(sqltypes.TypeEngine): + + """PostgreSQL UUID type. + + Represents the UUID column type, interpreting + data either as natively returned by the DBAPI + or as Python uuid objects. + + The UUID type may not be supported on all DBAPIs. + It is known to work on psycopg2 and not pg8000. + + """ __visit_name__ = 'UUID' -PGUuid = UUID -class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): - __visit_name__ = 'ARRAY' - - def __init__(self, item_type, mutable=True): - """Construct an ARRAY. + def __init__(self, as_uuid=False): + """Construct a UUID type. - E.g.:: - Column('myarray', ARRAY(Integer)) + :param as_uuid=False: if True, values will be interpreted + as Python uuid objects, converting to/from string via the + DBAPI. - Arguments are: + """ + if as_uuid and _python_UUID is None: + raise NotImplementedError( + "This version of Python does not support " + "the native UUID type." + ) + self.as_uuid = as_uuid - :param item_type: The data type of items of this array. Note that dimensionality is - irrelevant here, so multi-dimensional arrays like ``INTEGER[][]``, are constructed as - ``ARRAY(Integer)``, not as ``ARRAY(ARRAY(Integer))`` or such. The type mapping figures - out on the fly - - :param mutable: Defaults to True: specify whether lists passed to this class should be - considered mutable. If so, generic copy operations (typically used by the ORM) will - shallow-copy values. 
-
-        """
-        if isinstance(item_type, ARRAY):
-            raise ValueError("Do not nest ARRAY types; ARRAY(basetype) "
-                             "handles multi-dimensional arrays of basetype")
-        if isinstance(item_type, type):
-            item_type = item_type()
-        self.item_type = item_type
-        self.mutable = mutable
-
-    def copy_value(self, value):
-        if value is None:
-            return None
-        elif self.mutable:
-            return list(value)
-        else:
-            return value
-
-    def compare_values(self, x, y):
-        return x == y
-
-    def is_mutable(self):
-        return self.mutable
-
-    def dialect_impl(self, dialect, **kwargs):
-        impl = super(ARRAY, self).dialect_impl(dialect, **kwargs)
-        if impl is self:
-            impl = self.__class__.__new__(self.__class__)
-            impl.__dict__.update(self.__dict__)
-        impl.item_type = self.item_type.dialect_impl(dialect)
-        return impl
-
-    def adapt(self, impltype):
-        return impltype(
-            self.item_type,
-            mutable=self.mutable
-        )

     def bind_processor(self, dialect):
-        item_proc = self.item_type.bind_processor(dialect)
-        if item_proc:
-            def convert_item(item):
-                if isinstance(item, (list, tuple)):
-                    return [convert_item(child) for child in item]
-                else:
-                    return item_proc(item)
-        else:
-            def convert_item(item):
-                if isinstance(item, (list, tuple)):
-                    return [convert_item(child) for child in item]
-                else:
-                    return item
-        def process(value):
-            if value is None:
+        if self.as_uuid:
+            def process(value):
+                if value is not None:
+                    value = util.text_type(value)
                 return value
-            return [convert_item(item) for item in value]
-        return process
+            return process
+        else:
+            return None

     def result_processor(self, dialect, coltype):
-        item_proc = self.item_type.result_processor(dialect, coltype)
-        if item_proc:
-            def convert_item(item):
-                if isinstance(item, list):
-                    return [convert_item(child) for child in item]
-                else:
-                    return item_proc(item)
-        else:
-            def convert_item(item):
-                if isinstance(item, list):
-                    return [convert_item(child) for child in item]
-                else:
-                    return item
-        def process(value):
-            if value is None:
+        if self.as_uuid:
+            def process(value):
+                if value is not None:
+                    value = _python_UUID(value)
                 return value
-            return [convert_item(item) for item in value]
-        return process
-PGArray = ARRAY
+            return process
+        else:
+            return None
+
+PGUuid = UUID
+
+
+class TSVECTOR(sqltypes.TypeEngine):
+
+    """The :class:`.postgresql.TSVECTOR` type implements the PostgreSQL
+    text search type TSVECTOR.
+
+    It can be used to do full text queries on natural language
+    documents.
+
+    .. versionadded:: 0.9.0
+
+    .. seealso::
+
+        :ref:`postgresql_match`
+
+    """
+    __visit_name__ = 'TSVECTOR'
+

 class ENUM(sqltypes.Enum):
+
+    """PostgreSQL ENUM type.
+
+    This is a subclass of :class:`.types.Enum` which includes
+    support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
+
+    When the builtin type :class:`.types.Enum` is used and the
+    :paramref:`.Enum.native_enum` flag is left at its default of
+    True, the PostgreSQL backend will use a :class:`.postgresql.ENUM`
+    type as the implementation, so the special create/drop rules
+    will be used.
+
+    The create/drop behavior of ENUM is necessarily intricate, due to the
+    awkward relationship the ENUM type has in relation to the
+    parent table, in that it may be "owned" by just a single table, or
+    may be shared among many tables.
+
+    When using :class:`.types.Enum` or :class:`.postgresql.ENUM`
+    in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` are emitted
+    corresponding to when the :meth:`.Table.create` and :meth:`.Table.drop`
+    methods are called::
+
+        table = Table('sometable', metadata,
+            Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
+        )
+
+        table.create(engine)  # will emit CREATE ENUM and CREATE TABLE
+        table.drop(engine)  # will emit DROP TABLE and DROP ENUM
+
+    To use a common enumerated type between multiple tables, the best
+    practice is to declare the :class:`.types.Enum` or
+    :class:`.postgresql.ENUM` independently, and associate it with the
+    :class:`.MetaData` object itself::
+
+        my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
+
+        t1 = Table('sometable_one', metadata,
+            Column('some_enum', my_enum)
+        )
+
+        t2 = Table('sometable_two', metadata,
+            Column('some_enum', my_enum)
+        )
+
+    When this pattern is used, care must still be taken at the level
+    of individual table creates.  Emitting CREATE TABLE without also
+    specifying ``checkfirst=True`` will still cause issues::
+
+        t1.create(engine)  # will fail: no such type 'myenum'
+
+    If we specify ``checkfirst=True``, the individual table-level create
+    operation will check for the ``ENUM`` and create if not exists::
+
+        # will check if enum exists, and emit CREATE TYPE if not
+        t1.create(engine, checkfirst=True)
+
+    When using a metadata-level ENUM type, the type will always be created
+    and dropped when the metadata-wide create/drop is called::
+
+        metadata.create_all(engine)  # will emit CREATE TYPE
+        metadata.drop_all(engine)  # will emit DROP TYPE
+
+    The type can also be created and dropped directly::
+
+        my_enum.create(engine)
+        my_enum.drop(engine)
+
+    .. versionchanged:: 1.0.0 The PostgreSQL :class:`.postgresql.ENUM` type
+       now behaves more strictly with regard to CREATE/DROP.  A metadata-level
+       ENUM type will only be created and dropped at the metadata level,
+       not the table level, with the exception of
+       ``table.create(checkfirst=True)``.
+       The ``table.drop()`` call will now emit a DROP TYPE for a table-level
+       enumerated type.
+
+    """
+
+    def __init__(self, *enums, **kw):
+        """Construct an :class:`~.postgresql.ENUM`.
+
+        Arguments are the same as those of
+        :class:`.types.Enum`, but also including
+        the following parameters.
+
+        :param create_type: Defaults to True.
+         Indicates that ``CREATE TYPE`` should be
+         emitted, after optionally checking for the
+         presence of the type, when the parent
+         table is being created; and additionally
+         that ``DROP TYPE`` is called when the table
+         is dropped.  When ``False``, no check
+         will be performed and no ``CREATE TYPE``
+         or ``DROP TYPE`` is emitted, unless
+         :meth:`~.postgresql.ENUM.create`
+         or :meth:`~.postgresql.ENUM.drop`
+         are called directly.
+         Setting to ``False`` is helpful
+         when emitting a creation scheme to a SQL file
+         without access to the actual database -
+         the :meth:`~.postgresql.ENUM.create` and
+         :meth:`~.postgresql.ENUM.drop` methods can
+         be used to emit SQL to a target bind.
+
+         .. versionadded:: 0.7.4
+
+        """
+        self.create_type = kw.pop("create_type", True)
+        super(ENUM, self).__init__(*enums, **kw)
+
     def create(self, bind=None, checkfirst=True):
-        if not checkfirst or not bind.dialect.has_type(bind, self.name, schema=self.schema):
+        """Emit ``CREATE TYPE`` for this
+        :class:`~.postgresql.ENUM`.
+
+        If the underlying dialect does not support
+        PostgreSQL CREATE TYPE, no action is taken.
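+
+        E.g., a minimal sketch, assuming the metadata-level ``my_enum``
+        and the PostgreSQL ``engine`` from the examples above::
+
+            my_enum.create(engine, checkfirst=True)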
+
+        :param bind: a connectable :class:`.Engine`,
+         :class:`.Connection`, or similar object to emit
+         SQL.
+        :param checkfirst: if ``True``, a query against
+         the PG catalog will first be performed to see
+         whether the type already exists before
+         creating.
+
+        """
+        if not bind.dialect.supports_native_enum:
+            return
+
+        if not checkfirst or \
+                not bind.dialect.has_type(
+                    bind, self.name, schema=self.schema):
             bind.execute(CreateEnumType(self))

     def drop(self, bind=None, checkfirst=True):
-        if not checkfirst or bind.dialect.has_type(bind, self.name, schema=self.schema):
+        """Emit ``DROP TYPE`` for this
+        :class:`~.postgresql.ENUM`.
+
+        If the underlying dialect does not support
+        PostgreSQL DROP TYPE, no action is taken.
+
+        :param bind: a connectable :class:`.Engine`,
+         :class:`.Connection`, or similar object to emit
+         SQL.
+        :param checkfirst: if ``True``, a query against
+         the PG catalog will first be performed to see
+         if the type actually exists before dropping.
+
+        """
+        if not bind.dialect.supports_native_enum:
+            return
+
+        if not checkfirst or \
+                bind.dialect.has_type(bind, self.name, schema=self.schema):
             bind.execute(DropEnumType(self))
-
-    def _on_table_create(self, event, target, bind, **kw):
-        self.create(bind=bind, checkfirst=True)

-    def _on_metadata_create(self, event, target, bind, **kw):
-        if self.metadata is not None:
-            self.create(bind=bind, checkfirst=True)
+    def _check_for_name_in_memos(self, checkfirst, kw):
+        """Look in the 'ddl runner' for 'memos', then
+        note our name in that collection.

-    def _on_metadata_drop(self, event, target, bind, **kw):
-        self.drop(bind=bind, checkfirst=True)
+        This is to ensure a particular named enum is operated
+        upon only once within any kind of create/drop
+        sequence without relying upon "checkfirst".
+ + """ + if not self.create_type: + return True + if '_ddl_runner' in kw: + ddl_runner = kw['_ddl_runner'] + if '_pg_enums' in ddl_runner.memo: + pg_enums = ddl_runner.memo['_pg_enums'] + else: + pg_enums = ddl_runner.memo['_pg_enums'] = set() + present = self.name in pg_enums + pg_enums.add(self.name) + return present + else: + return False + + def _on_table_create(self, target, bind, checkfirst=False, **kw): + if checkfirst or ( + not self.metadata and + not kw.get('_is_metadata_operation', False)) and \ + not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_table_drop(self, target, bind, checkfirst=False, **kw): + if not self.metadata and \ + not kw.get('_is_metadata_operation', False) and \ + not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) + + def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) colspecs = { - sqltypes.Interval:INTERVAL, - sqltypes.Enum:ENUM, + sqltypes.Interval: INTERVAL, + sqltypes.Enum: ENUM, } ischema_names = { - 'integer' : INTEGER, - 'bigint' : BIGINT, - 'smallint' : SMALLINT, - 'character varying' : VARCHAR, - 'character' : CHAR, - '"char"' : sqltypes.String, - 'name' : sqltypes.String, - 'text' : TEXT, - 'numeric' : NUMERIC, - 'float' : FLOAT, - 'real' : REAL, + 'integer': INTEGER, + 'bigint': BIGINT, + 'smallint': SMALLINT, + 'character varying': VARCHAR, + 'character': CHAR, + '"char"': sqltypes.String, + 'name': sqltypes.String, + 'text': TEXT, + 'numeric': NUMERIC, + 'float': FLOAT, + 'real': REAL, 'inet': INET, 'cidr': CIDR, 'uuid': UUID, - 'bit':BIT, + 'bit': BIT, + 'bit varying': BIT, 'macaddr': MACADDR, - 'double precision' : DOUBLE_PRECISION, - 'timestamp' : TIMESTAMP, - 'timestamp with time zone' : TIMESTAMP, - 'timestamp without time zone' : TIMESTAMP, - 'time with time zone' : TIME, - 'time without time zone' : TIME, - 'date' : DATE, + 'oid': OID, + 'double precision': DOUBLE_PRECISION, + 'timestamp': TIMESTAMP, + 'timestamp with time zone': TIMESTAMP, + 'timestamp without time zone': TIMESTAMP, + 'time with time zone': TIME, + 'time without time zone': TIME, + 'date': DATE, 'time': TIME, - 'bytea' : BYTEA, - 'boolean' : BOOLEAN, - 'interval':INTERVAL, - 'interval year to month':INTERVAL, - 'interval day to second':INTERVAL, + 'bytea': BYTEA, + 'boolean': BOOLEAN, + 'interval': INTERVAL, + 'interval year to month': INTERVAL, + 'interval day to second': INTERVAL, + 'tsvector': TSVECTOR } - class PGCompiler(compiler.SQLCompiler): - - def visit_match_op(self, binary, **kw): - return "%s @@ to_tsquery(%s)" % (self.process(binary.left), self.process(binary.right)) - def visit_ilike_op(self, binary, **kw): - escape = binary.modifiers.get("escape", None) - return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \ - + (escape and ' ESCAPE \'%s\'' % escape or '') + def visit_array(self, element, **kw): + return "ARRAY[%s]" % self.visit_clauselist(element, **kw) - def visit_notilike_op(self, binary, **kw): + def visit_slice(self, element, **kw): + return "%s:%s" % ( + self.process(element.start, **kw), + self.process(element.stop, **kw), + ) + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + kw['eager_grouping'] = True + 
return self._generate_generic_binary( + binary, " -> ", **kw + ) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + kw['eager_grouping'] = True + return self._generate_generic_binary( + binary, " #> ", **kw + ) + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s[%s]" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_aggregate_order_by(self, element, **kw): + return "%s ORDER BY %s" % ( + self.process(element.target, **kw), + self.process(element.order_by, **kw) + ) + + def visit_match_op_binary(self, binary, operator, **kw): + if "postgresql_regconfig" in binary.modifiers: + regconfig = self.render_literal_value( + binary.modifiers['postgresql_regconfig'], + sqltypes.STRINGTYPE) + if regconfig: + return "%s @@ to_tsquery(%s, %s)" % ( + self.process(binary.left, **kw), + regconfig, + self.process(binary.right, **kw) + ) + return "%s @@ to_tsquery(%s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \ - + (escape and ' ESCAPE \'%s\'' % escape or '') + + return '%s ILIKE %s' % \ + (self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def visit_notilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT ILIKE %s' % \ + (self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def render_literal_value(self, value, type_): + value = super(PGCompiler, self).render_literal_value(value, type_) + + if self.dialect._backslash_escapes: + value = value.replace('\\', '\\\\') + return value def visit_sequence(self, seq): - if seq.optional: - return None - else: - return "nextval('%s')" % self.preparer.format_sequence(seq) + return "nextval('%s')" % self.preparer.format_sequence(seq) - def limit_clause(self, select): + def limit_clause(self, select, **kw): text = "" - if select._limit is not None: - text += " \n LIMIT " + str(select._limit) - if select._offset is not None: - if select._limit is None: + if select._limit_clause is not None: + text += " \n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: text += " \n LIMIT ALL" - text += " OFFSET " + str(select._offset) + text += " OFFSET " + self.process(select._offset_clause, **kw) return text - def get_select_precolumns(self, select): + def format_from_hint_text(self, sqltext, table, hint, iscrud): + if hint.upper() != 'ONLY': + raise exc.CompileError("Unrecognized hint: %r" % hint) + return "ONLY " + sqltext + + def get_select_precolumns(self, select, **kw): if select._distinct is not False: if select._distinct is True: return "DISTINCT " elif isinstance(select._distinct, (list, tuple)): return "DISTINCT ON (" + ', '.join( - [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct] - )+ ") " + [self.process(col) for col in select._distinct] + ) + ") " else: - return "DISTINCT ON (" + unicode(select._distinct) + ") " + return "DISTINCT ON (" + \ + self.process(select._distinct, **kw) + ") " else: return "" - def for_update_clause(self, select): - if 
select.for_update == 'nowait': - return " FOR UPDATE NOWAIT" + def for_update_clause(self, select, **kw): + + if select._for_update_arg.read: + if select._for_update_arg.key_share: + tmp = " FOR KEY SHARE" + else: + tmp = " FOR SHARE" + elif select._for_update_arg.key_share: + tmp = " FOR NO KEY UPDATE" else: - return super(PGCompiler, self).for_update_clause(select) + tmp = " FOR UPDATE" + + if select._for_update_arg.of: + tables = util.OrderedSet( + c.table if isinstance(c, expression.ColumnClause) + else c for c in select._for_update_arg.of) + tmp += " OF " + ", ".join( + self.process(table, ashint=True, use_schema=False, **kw) + for table in tables + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp def returning_clause(self, stmt, returning_cols): - + columns = [ - self.process( - self.label_select_column(None, c, asfrom=False), - within_columns_clause=True, - result_map=self.result_map) - for c in expression._select_iterables(returning_cols) - ] - + self._label_select_column(None, c, True, False, {}) + for c in expression._select_iterables(returning_cols) + ] + return 'RETURNING ' + ', '.join(columns) - def visit_extract(self, extract, **kwargs): - field = self.extract_map.get(extract.field, extract.field) - if extract.expr.type: - affinity = extract.expr.type._type_affinity + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0], **kw) + start = self.process(func.clauses.clauses[1], **kw) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2], **kw) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) else: - affinity = None - - casts = { - sqltypes.Date:'date', - sqltypes.DateTime:'timestamp', - sqltypes.Interval:'interval', sqltypes.Time:'time' - } - cast = casts.get(affinity, None) - if isinstance(extract.expr, sql.ColumnElement) and cast is not None: - expr = extract.expr.op('::')(sql.literal_column(cast)) + return "SUBSTRING(%s FROM %s)" % (s, start) + + def _on_conflict_target(self, clause, **kw): + + if clause.constraint_target is not None: + target_text = 'ON CONSTRAINT %s' % clause.constraint_target + elif clause.inferred_target_elements is not None: + target_text = '(%s)' % ', '.join( + (self.preparer.quote(c) + if isinstance(c, util.string_types) + else + self.process(c, include_table=False, use_schema=False)) + for c in clause.inferred_target_elements + ) + if clause.inferred_target_whereclause is not None: + target_text += ' WHERE %s' % \ + self.process( + clause.inferred_target_whereclause, + include_table=False, + use_schema=False + ) else: - expr = extract.expr - return "EXTRACT(%s FROM %s)" % ( - field, self.process(expr)) + target_text = '' + + return target_text + + def visit_on_conflict_do_nothing(self, on_conflict, **kw): + + target_text = self._on_conflict_target(on_conflict, **kw) + + if target_text: + return "ON CONFLICT %s DO NOTHING" % target_text + else: + return "ON CONFLICT DO NOTHING" + + def visit_on_conflict_do_update(self, on_conflict, **kw): + + clause = on_conflict + + target_text = self._on_conflict_target(on_conflict, **kw) + + action_set_ops = [] + + set_parameters = dict(clause.update_values_to_set) + # create a list of column assignment clauses as tuples + cols = self.statement.table.c + for c in cols: + col_key = c.key + if col_key in set_parameters: + value = set_parameters.pop(col_key) + if elements._is_literal(value): + value = elements.BindParameter( + None, value, type_=c.type + ) 
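+                    # a plain Python value was supplied for this column;
+                    # wrap it as a bind parameter typed to the target column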
+ + else: + if isinstance(value, elements.BindParameter) and \ + value.type._isnull: + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = ( + self.preparer.quote(col_key) + ) + action_set_ops.append('%s = %s' % (key_text, value_text)) + + # check for names that don't match columns + if set_parameters: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" % ( + self.statement.table.name, + (", ".join("'%s'" % c for c in set_parameters)) + ) + ) + for k, v in set_parameters.items(): + key_text = ( + self.preparer.quote(k) + if isinstance(k, util.string_types) + else self.process(k, use_schema=False) + ) + value_text = self.process( + elements._literal_as_binds(v), + use_schema=False + ) + action_set_ops.append('%s = %s' % (key_text, value_text)) + + action_text = ', '.join(action_set_ops) + if clause.update_whereclause is not None: + action_text += ' WHERE %s' % \ + self.process( + clause.update_whereclause, + include_table=True, + use_schema=False + ) + + return 'ON CONFLICT %s DO UPDATE SET %s' % (target_text, action_text) + class PGDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + impl_type = column.type.dialect_impl(self.dialect) + if isinstance(impl_type, sqltypes.TypeDecorator): + impl_type = impl_type.impl + if column.primary_key and \ - len(column.foreign_keys)==0 and \ - column.autoincrement and \ - isinstance(column.type, sqltypes.Integer) and \ - not isinstance(column.type, sqltypes.SmallInteger) and \ - (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): - if isinstance(column.type, sqltypes.BigInteger): + column is column.table._autoincrement_column and \ + ( + self.dialect.supports_smallserial or + not isinstance(impl_type, sqltypes.SmallInteger) + ) and ( + column.default is None or + ( + isinstance(column.default, schema.Sequence) and + column.default.optional + )): + if isinstance(impl_type, sqltypes.BigInteger): colspec += " BIGSERIAL" + elif isinstance(impl_type, sqltypes.SmallInteger): + colspec += " SMALLSERIAL" else: colspec += " SERIAL" else: - colspec += " " + self.dialect.type_compiler.process(column.type) + colspec += " " + self.dialect.type_compiler.process( + column.type, type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -402,10 +1603,12 @@ class PGDDLCompiler(compiler.DDLCompiler): def visit_create_enum_type(self, create): type_ = create.element - + return "CREATE TYPE %s AS ENUM (%s)" % ( self.preparer.format_type(type_), - ",".join("'%s'" % e for e in type_.enums) + ", ".join( + self.sql_compiler.process(sql.literal(e), literal_binds=True) + for e in type_.enums) ) def visit_drop_enum_type(self, drop): @@ -414,188 +1617,418 @@ class PGDDLCompiler(compiler.DDLCompiler): return "DROP TYPE %s" % ( self.preparer.format_type(type_) ) - + def visit_create_index(self, create): preparer = self.preparer index = create.element + self._verify_index_table(index) text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s (%s)" \ - % (preparer.quote(self._validate_identifier(index.name, True), index.quote), - preparer.format_table(index.table), - ', '.join([preparer.format_column(c) for c in index.columns])) - - if "postgres_where" in index.kwargs: - whereclause = index.kwargs['postgres_where'] - util.warn_deprecated("The 
'postgres_where' argument has been renamed to 'postgresql_where'.") - elif 'postgresql_where' in index.kwargs: - whereclause = index.kwargs['postgresql_where'] - else: - whereclause = None - + text += "INDEX " + + if self.dialect._supports_create_index_concurrently: + concurrently = index.dialect_options['postgresql']['concurrently'] + if concurrently: + text += "CONCURRENTLY " + + text += "%s ON %s " % ( + self._prepared_index_name(index, + include_schema=False), + preparer.format_table(index.table) + ) + + using = index.dialect_options['postgresql']['using'] + if using: + text += "USING %s " % preparer.quote(using) + + ops = index.dialect_options["postgresql"]["ops"] + text += "(%s)" \ + % ( + ', '.join([ + self.sql_compiler.process( + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr, + include_table=False, literal_binds=True) + + ( + (' ' + ops[expr.key]) + if hasattr(expr, 'key') + and expr.key in ops else '' + ) + for expr in index.expressions + ]) + ) + + withclause = index.dialect_options['postgresql']['with'] + + if withclause: + text += " WITH (%s)" % (', '.join( + ['%s = %s' % storage_parameter + for storage_parameter in withclause.items()])) + + tablespace_name = index.dialect_options['postgresql']['tablespace'] + + if tablespace_name: + text += " TABLESPACE %s" % preparer.quote(tablespace_name) + + whereclause = index.dialect_options["postgresql"]["where"] + if whereclause is not None: - whereclause = sql_util.expression_as_ddl(whereclause) - where_compiled = self.sql_compiler.process(whereclause) + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, + literal_binds=True) text += " WHERE " + where_compiled return text + def visit_drop_index(self, drop): + index = drop.element + + text = "\nDROP INDEX " + + if self.dialect._supports_drop_index_concurrently: + concurrently = index.dialect_options['postgresql']['concurrently'] + if concurrently: + text += "CONCURRENTLY " + + text += self._prepared_index_name(index, include_schema=True) + return text + + def visit_exclude_constraint(self, constraint, **kw): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + elements = [] + for expr, name, op in constraint._render_exprs: + kw['include_table'] = False + elements.append( + "%s WITH %s" % (self.sql_compiler.process(expr, **kw), op) + ) + text += "EXCLUDE USING %s (%s)" % (constraint.using, + ', '.join(elements)) + if constraint.where is not None: + text += ' WHERE (%s)' % self.sql_compiler.process( + constraint.where, + literal_binds=True) + text += self.define_constraint_deferrability(constraint) + return text + + def post_create_table(self, table): + table_opts = [] + pg_opts = table.dialect_options['postgresql'] + + inherits = pg_opts.get('inherits') + if inherits is not None: + if not isinstance(inherits, (list, tuple)): + inherits = (inherits, ) + table_opts.append( + '\n INHERITS ( ' + + ', '.join(self.preparer.quote(name) for name in inherits) + + ' )') + + if pg_opts['with_oids'] is True: + table_opts.append('\n WITH OIDS') + elif pg_opts['with_oids'] is False: + table_opts.append('\n WITHOUT OIDS') + + if pg_opts['on_commit']: + on_commit_options = pg_opts['on_commit'].replace("_", " ").upper() + table_opts.append('\n ON COMMIT %s' % on_commit_options) + + if pg_opts['tablespace']: + tablespace_name = pg_opts['tablespace'] + table_opts.append( + '\n TABLESPACE %s' % self.preparer.quote(tablespace_name) + ) + + return ''.join(table_opts) + class 
PGTypeCompiler(compiler.GenericTypeCompiler): - def visit_INET(self, type_): + def visit_TSVECTOR(self, type, **kw): + return "TSVECTOR" + + def visit_INET(self, type_, **kw): return "INET" - def visit_CIDR(self, type_): + def visit_CIDR(self, type_, **kw): return "CIDR" - def visit_MACADDR(self, type_): + def visit_MACADDR(self, type_, **kw): return "MACADDR" - def visit_FLOAT(self, type_): + def visit_OID(self, type_, **kw): + return "OID" + + def visit_FLOAT(self, type_, **kw): if not type_.precision: return "FLOAT" else: return "FLOAT(%(precision)s)" % {'precision': type_.precision} - - def visit_DOUBLE_PRECISION(self, type_): + + def visit_DOUBLE_PRECISION(self, type_, **kw): return "DOUBLE PRECISION" - - def visit_BIGINT(self, type_): + + def visit_BIGINT(self, type_, **kw): return "BIGINT" - def visit_datetime(self, type_): - return self.visit_TIMESTAMP(type_) - - def visit_enum(self, type_): + def visit_HSTORE(self, type_, **kw): + return "HSTORE" + + def visit_JSON(self, type_, **kw): + return "JSON" + + def visit_JSONB(self, type_, **kw): + return "JSONB" + + def visit_INT4RANGE(self, type_, **kw): + return "INT4RANGE" + + def visit_INT8RANGE(self, type_, **kw): + return "INT8RANGE" + + def visit_NUMRANGE(self, type_, **kw): + return "NUMRANGE" + + def visit_DATERANGE(self, type_, **kw): + return "DATERANGE" + + def visit_TSRANGE(self, type_, **kw): + return "TSRANGE" + + def visit_TSTZRANGE(self, type_, **kw): + return "TSTZRANGE" + + def visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) + + def visit_enum(self, type_, **kw): if not type_.native_enum or not self.dialect.supports_native_enum: - return super(PGTypeCompiler, self).visit_enum(type_) + return super(PGTypeCompiler, self).visit_enum(type_, **kw) else: - return self.visit_ENUM(type_) - - def visit_ENUM(self, type_): + return self.visit_ENUM(type_, **kw) + + def visit_ENUM(self, type_, **kw): return self.dialect.identifier_preparer.format_type(type_) - - def visit_TIMESTAMP(self, type_): + + def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( - getattr(type_, 'precision', None) and "(%d)" % type_.precision or "", + "(%d)" % type_.precision + if getattr(type_, 'precision', None) is not None else "", (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" ) - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): return "TIME%s %s" % ( - getattr(type_, 'precision', None) and "(%d)" % type_.precision or "", + "(%d)" % type_.precision + if getattr(type_, 'precision', None) is not None else "", (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" ) - def visit_INTERVAL(self, type_): + def visit_INTERVAL(self, type_, **kw): if type_.precision is not None: return "INTERVAL(%d)" % type_.precision else: return "INTERVAL" - def visit_BIT(self, type_): - return "BIT" + def visit_BIT(self, type_, **kw): + if type_.varying: + compiled = "BIT VARYING" + if type_.length is not None: + compiled += "(%d)" % type_.length + else: + compiled = "BIT(%d)" % type_.length + return compiled - def visit_UUID(self, type_): + def visit_UUID(self, type_, **kw): return "UUID" - def visit_large_binary(self, type_): - return self.visit_BYTEA(type_) - - def visit_BYTEA(self, type_): + def visit_large_binary(self, type_, **kw): + return self.visit_BYTEA(type_, **kw) + + def visit_BYTEA(self, type_, **kw): return "BYTEA" - def visit_REAL(self, type_): - return "REAL" - - def visit_ARRAY(self, type_): - return self.process(type_.item_type) + '[]' + def visit_ARRAY(self, type_, **kw): + return 
self.process(type_.item_type) + ('[]' * (type_.dimensions
+                                                        if type_.dimensions
+                                                        is not None else 1))


 class PGIdentifierPreparer(compiler.IdentifierPreparer):
+
+    reserved_words = RESERVED_WORDS
+
     def _unquote_identifier(self, value):
         if value[0] == self.initial_quote:
-            value = value[1:-1].replace(self.escape_to_quote, self.escape_quote)
+            value = value[1:-1].\
+                replace(self.escape_to_quote, self.escape_quote)
         return value

     def format_type(self, type_, use_schema=True):
         if not type_.name:
-            raise exc.ArgumentError("Postgresql ENUM type requires a name.")
-
-        name = self.quote(type_.name, type_.quote)
-        if not self.omit_schema and use_schema and type_.schema is not None:
-            name = self.quote_schema(type_.schema, type_.quote) + "." + name
+            raise exc.CompileError("PostgreSQL ENUM type requires a name.")
+
+        name = self.quote(type_.name)
+        effective_schema = self.schema_for_object(type_)
+
+        if not self.omit_schema and use_schema and \
+                effective_schema is not None:
+            name = self.quote_schema(effective_schema) + "." + name
         return name
-
+
+
 class PGInspector(reflection.Inspector):

     def __init__(self, conn):
         reflection.Inspector.__init__(self, conn)

     def get_table_oid(self, table_name, schema=None):
-        """Return the oid from `table_name` and `schema`."""
+        """Return the OID for the given table name."""

-        return self.dialect.get_table_oid(self.conn, table_name, schema,
+        return self.dialect.get_table_oid(self.bind, table_name, schema,
                                           info_cache=self.info_cache)

+    def get_enums(self, schema=None):
+        """Return a list of ENUM objects.
+
+        Each member is a dictionary containing these fields:
+
+            * name - name of the enum
+            * schema - the schema name for the enum.
+            * visible - boolean, whether or not this enum is visible
+              in the default search path.
+            * labels - a list of string labels that apply to the enum.
+
+        :param schema: schema name.  If None, the default schema
+         (typically 'public') is used.  May also be set to '*' to
+         indicate load enums for all schemas.
+
+        .. versionadded:: 1.0.0
+
+        """
+        schema = schema or self.default_schema_name
+        return self.dialect._load_enums(self.bind, schema)
+
+    def get_foreign_table_names(self, schema=None):
+        """Return a list of FOREIGN TABLE names.
+
+        Behavior is similar to that of :meth:`.Inspector.get_table_names`,
+        except that the list is limited to those tables that report a
+        ``relkind`` value of ``f``.
+
+        .. versionadded:: 1.0.0
+
+        """
+        schema = schema or self.default_schema_name
+        return self.dialect._get_foreign_table_names(self.bind, schema)
+
+    def get_view_names(self, schema=None, include=('plain', 'materialized')):
+        """Return all view names in `schema`.
+
+        :param schema: Optional, retrieve names from a non-default schema.
+         For special quoting, use :class:`.quoted_name`.
+
+        :param include: specify which types of views to return.  Passed
+         as a string value (for a single type) or a tuple (for any number
+         of types).  Defaults to ``('plain', 'materialized')``.
+
+        ..
versionadded:: 1.1 + + """ + + return self.dialect.get_view_names(self.bind, schema, + info_cache=self.info_cache, + include=include) + + class CreateEnumType(schema._CreateDropBase): - __visit_name__ = "create_enum_type" + __visit_name__ = "create_enum_type" + class DropEnumType(schema._CreateDropBase): - __visit_name__ = "drop_enum_type" + __visit_name__ = "drop_enum_type" + class PGExecutionContext(default.DefaultExecutionContext): - def fire_sequence(self, seq): - if not seq.optional: - return self._execute_scalar(("select nextval('%s')" % \ - self.dialect.identifier_preparer.format_sequence(seq))) - else: - return None + + def fire_sequence(self, seq, type_): + return self._execute_scalar(( + "select nextval('%s')" % + self.dialect.identifier_preparer.format_sequence(seq)), type_) def get_insert_default(self, column): - if column.primary_key: - if (isinstance(column.server_default, schema.DefaultClause) and - column.server_default.arg is not None): + if column.primary_key and \ + column is column.table._autoincrement_column: + if column.server_default and column.server_default.has_argument: # pre-execute passive defaults on primary key columns - return self._execute_scalar("select %s" % column.server_default.arg) + return self._execute_scalar("select %s" % + column.server_default.arg, + column.type) - elif column is column.table._autoincrement_column \ - and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + elif (column.default is None or + (column.default.is_sequence and + column.default.optional)): - # execute the sequence associated with a SERIAL primary key column. - # for non-primary-key SERIAL, the ID just generates server side. - sch = column.table.schema + # execute the sequence associated with a SERIAL primary + # key column. for non-primary-key SERIAL, the ID just + # generates server side. 
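+
+                # PostgreSQL caps identifiers at 63 characters; derive the
+                # implicit "<table>_<column>_seq" name with matching
+                # truncation so nextval() targets the sequence the server
+                # actually created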
- if sch is not None: - exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) + try: + seq_name = column._postgresql_seq_name + except AttributeError: + tab = column.table.name + col = column.name + tab = tab[0:29 + max(0, (29 - len(col)))] + col = col[0:29 + max(0, (29 - len(tab)))] + name = "%s_%s_seq" % (tab, col) + column._postgresql_seq_name = seq_name = name + + if column.table is not None: + effective_schema = self.connection.schema_for_object( + column.table) else: - exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) + effective_schema = None - return self._execute_scalar(exc) + if effective_schema is not None: + exc = "select nextval('\"%s\".\"%s\"')" % \ + (effective_schema, seq_name) + else: + exc = "select nextval('\"%s\"')" % \ + (seq_name, ) + + return self._execute_scalar(exc, column.type) return super(PGExecutionContext, self).get_insert_default(column) - + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_REGEXP.match(statement) + + class PGDialect(default.DefaultDialect): name = 'postgresql' supports_alter = True max_identifier_length = 63 supports_sane_rowcount = True - + supports_native_enum = True supports_native_boolean = True - + supports_smallserial = True + supports_sequences = True sequences_optional = True preexecute_autoincrement_sequences = True postfetch_lastrowid = False - + supports_default_values = True supports_empty_insert = False + supports_multivalues_insert = True default_paramstyle = 'pyformat' ischema_names = ischema_names colspecs = colspecs - + statement_compiler = PGCompiler ddl_compiler = PGDDLCompiler type_compiler = PGTypeCompiler @@ -604,42 +2037,109 @@ class PGDialect(default.DefaultDialect): inspector = PGInspector isolation_level = None - def __init__(self, isolation_level=None, **kwargs): + construct_arguments = [ + (schema.Index, { + "using": False, + "where": None, + "ops": {}, + "concurrently": False, + "with": {}, + "tablespace": None + }), + (schema.Table, { + "ignore_search_path": False, + "tablespace": None, + "with_oids": None, + "on_commit": None, + "inherits": None + }), + ] + + reflection_options = ('postgresql_ignore_search_path', ) + + _backslash_escapes = True + _supports_create_index_concurrently = True + _supports_drop_index_concurrently = True + + def __init__(self, isolation_level=None, json_serializer=None, + json_deserializer=None, **kwargs): default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level + self._json_deserializer = json_deserializer + self._json_serializer = json_serializer def initialize(self, connection): super(PGDialect, self).initialize(connection) self.implicit_returning = self.server_version_info > (8, 2) and \ - self.__dict__.get('implicit_returning', True) + self.__dict__.get('implicit_returning', True) self.supports_native_enum = self.server_version_info >= (8, 3) if not self.supports_native_enum: self.colspecs = self.colspecs.copy() - del self.colspecs[ENUM] + # pop base Enum type + self.colspecs.pop(sqltypes.Enum, None) + # psycopg2, others may have placed ENUM here as well + self.colspecs.pop(ENUM, None) + + # http://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 + self.supports_smallserial = self.server_version_info >= (9, 2) + + self._backslash_escapes = self.server_version_info < (8, 2) or \ + connection.scalar( + "show standard_conforming_strings" + ) == 'off' + + self._supports_create_index_concurrently = \ + self.server_version_info >= (8, 2) + 
self._supports_drop_index_concurrently = \ + self.server_version_info >= (9, 2) def on_connect(self): if self.isolation_level is not None: def connect(conn): - cursor = conn.cursor() - cursor.execute("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL %s" - % self.isolation_level) - cursor.execute("COMMIT") - cursor.close() + self.set_isolation_level(conn, self.isolation_level) return connect else: return None - + + _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', + 'READ COMMITTED', 'REPEATABLE READ']) + + def set_isolation_level(self, connection, level): + level = level.replace('_', ' ') + if level not in self._isolation_lookup: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" % + (level, self.name, ", ".join(self._isolation_lookup)) + ) + cursor = connection.cursor() + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION " + "ISOLATION LEVEL %s" % level) + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, connection): + cursor = connection.cursor() + cursor.execute('show transaction isolation level') + val = cursor.fetchone()[0] + cursor.close() + return val.upper() + def do_begin_twophase(self, connection, xid): self.do_begin(connection.connection) def do_prepare_twophase(self, connection, xid): connection.execute("PREPARE TRANSACTION '%s'" % xid) - def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + def do_rollback_twophase(self, connection, xid, + is_prepared=True, recover=False): if is_prepared: if recover: - #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions - # Must find out a way how to make the dbapi not open a transaction. + # FIXME: ugly hack to get out of transaction + # context when committing recoverable transactions + # Must find out a way how to make the dbapi not + # open a transaction. connection.execute("ROLLBACK") connection.execute("ROLLBACK PREPARED '%s'" % xid) connection.execute("BEGIN") @@ -647,7 +2147,8 @@ class PGDialect(default.DefaultDialect): else: self.do_rollback(connection.connection) - def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + def do_commit_twophase(self, connection, xid, + is_prepared=True, recover=False): if is_prepared: if recover: connection.execute("ROLLBACK") @@ -658,31 +2159,55 @@ class PGDialect(default.DefaultDialect): self.do_commit(connection.connection) def do_recover_twophase(self, connection): - resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) + resultset = connection.execute( + sql.text("SELECT gid FROM pg_prepared_xacts")) return [row[0] for row in resultset] def _get_default_schema_name(self, connection): return connection.scalar("select current_schema()") + def has_schema(self, connection, schema): + query = ("select nspname from pg_namespace " + "where lower(nspname)=:schema") + cursor = connection.execute( + sql.text( + query, + bindparams=[ + sql.bindparam( + 'schema', util.text_type(schema.lower()), + type_=sqltypes.Unicode)] + ) + ) + + return bool(cursor.first()) + def has_table(self, connection, table_name, schema=None): # seems like case gets folded in pg_class... 
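+        # with no schema given, pg_table_is_visible() limits the match to
+        # tables visible on the current search_path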
if schema is None: cursor = connection.execute( - sql.text("select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=current_schema() and " - "lower(relname)=:name", + sql.text( + "select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where " + "pg_catalog.pg_table_is_visible(c.oid) " + "and relname=:name", bindparams=[ - sql.bindparam('name', unicode(table_name.lower()), - type_=sqltypes.Unicode)] + sql.bindparam('name', util.text_type(table_name), + type_=sqltypes.Unicode)] ) ) else: cursor = connection.execute( - sql.text("select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=:schema and lower(relname)=:name", + sql.text( + "select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=:schema and " + "relname=:name", bindparams=[ - sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode), - sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)] + sql.bindparam('name', + util.text_type(table_name), + type_=sqltypes.Unicode), + sql.bindparam('schema', + util.text_type(schema), + type_=sqltypes.Unicode)] ) ) return bool(cursor.first()) @@ -690,37 +2215,36 @@ class PGDialect(default.DefaultDialect): def has_sequence(self, connection, sequence_name, schema=None): if schema is None: cursor = connection.execute( - sql.text("SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and n.nspname=current_schema()" - " and lower(relname)=:name", - bindparams=[ - sql.bindparam('name', unicode(sequence_name.lower()), - type_=sqltypes.Unicode) - ] - ) - ) + sql.text( + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=current_schema() " + "and relname=:name", + bindparams=[ + sql.bindparam('name', util.text_type(sequence_name), + type_=sqltypes.Unicode) + ] + ) + ) else: cursor = connection.execute( - sql.text("SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and n.nspname=:schema and " - "lower(relname)=:name", - bindparams=[ - sql.bindparam('name', unicode(sequence_name.lower()), - type_=sqltypes.Unicode), - sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode) - ] - ) + sql.text( + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=:schema and relname=:name", + bindparams=[ + sql.bindparam('name', util.text_type(sequence_name), + type_=sqltypes.Unicode), + sql.bindparam('schema', + util.text_type(schema), + type_=sqltypes.Unicode) + ] ) + ) return bool(cursor.first()) def has_type(self, connection, type_name, schema=None): - bindparams = [ - sql.bindparam('typname', - unicode(type_name), type_=sqltypes.Unicode), - sql.bindparam('nspname', - unicode(schema), type_=sqltypes.Unicode), - ] if schema is not None: query = """ SELECT EXISTS ( @@ -730,6 +2254,7 @@ class PGDialect(default.DefaultDialect): AND n.nspname = :nspname ) """ + query = sql.text(query) else: query = """ SELECT EXISTS ( @@ -738,14 +2263,28 @@ class PGDialect(default.DefaultDialect): AND pg_type_is_visible(t.oid) ) """ - cursor = connection.execute(sql.text(query, bindparams=bindparams)) + query = sql.text(query) + query = query.bindparams( + sql.bindparam('typname', + util.text_type(type_name), type_=sqltypes.Unicode), + ) + if schema is not None: + query = query.bindparams( + sql.bindparam('nspname', + util.text_type(schema), 
type_=sqltypes.Unicode), + ) + cursor = connection.execute(query) return bool(cursor.scalar()) def _get_server_version_info(self, connection): v = connection.execute("select version()").scalar() - m = re.match('PostgreSQL (\d+)\.(\d+)(?:\.(\d+))?(?:devel)?', v) + m = re.match( + r'.*(?:PostgreSQL|EnterpriseDB) ' + r'(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel)?', + v) if not m: - raise AssertionError("Could not determine version from string '%s'" % v) + raise AssertionError( + "Could not determine version from string '%s'" % v) return tuple([int(x) for x in m.group(1, 2, 3) if x is not None]) @reflection.cache @@ -767,19 +2306,17 @@ class PGDialect(default.DefaultDialect): FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE (%s) - AND c.relname = :table_name AND c.relkind in ('r','v') + AND c.relname = :table_name AND c.relkind in ('r', 'v', 'm', 'f') """ % schema_where_clause # Since we're binding to unicode, table_name and schema_name must be # unicode. - table_name = unicode(table_name) + table_name = util.text_type(table_name) if schema is not None: - schema = unicode(schema) - s = sql.text(query, bindparams=[ - sql.bindparam('table_name', type_=sqltypes.Unicode), - sql.bindparam('schema', type_=sqltypes.Unicode) - ], - typemap={'oid':sqltypes.Integer} - ) + schema = util.text_type(schema) + s = sql.text(query).bindparams(table_name=sqltypes.Unicode) + s = s.columns(oid=sqltypes.Integer) + if schema: + s = s.bindparams(sql.bindparam('schema', type_=sqltypes.Unicode)) c = connection.execute(s, table_name=table_name, schema=schema) table_oid = c.scalar() if table_oid is None: @@ -788,79 +2325,70 @@ class PGDialect(default.DefaultDialect): @reflection.cache def get_schema_names(self, connection, **kw): - s = """ - SELECT nspname - FROM pg_namespace - ORDER BY nspname - """ - rp = connection.execute(s) - # what about system tables? 
- # Py3K - #schema_names = [row[0] for row in rp \ - # if not row[0].startswith('pg_')] - # Py2K - schema_names = [row[0].decode(self.encoding) for row in rp \ - if not row[0].startswith('pg_')] - # end Py2K - return schema_names + result = connection.execute( + sql.text("SELECT nspname FROM pg_namespace " + "WHERE nspname NOT LIKE 'pg_%' " + "ORDER BY nspname" + ).columns(nspname=sqltypes.Unicode)) + return [name for name, in result] @reflection.cache def get_table_names(self, connection, schema=None, **kw): - if schema is not None: - current_schema = schema - else: - current_schema = self.default_schema_name - result = connection.execute( - sql.text(u"SELECT relname FROM pg_class c " - "WHERE relkind = 'r' " - "AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) " % - current_schema, - typemap = {'relname':sqltypes.Unicode} - ) - ) - return [row[0] for row in result] - + sql.text("SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind = 'r'" + ).columns(relname=sqltypes.Unicode), + schema=schema if schema is not None else self.default_schema_name) + return [name for name, in result] @reflection.cache - def get_view_names(self, connection, schema=None, **kw): - if schema is not None: - current_schema = schema - else: - current_schema = self.default_schema_name - s = """ - SELECT relname - FROM pg_class c - WHERE relkind = 'v' - AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) - """ % dict(schema=current_schema) - # Py3K - #view_names = [row[0] for row in connection.execute(s)] - # Py2K - view_names = [row[0].decode(self.encoding) for row in connection.execute(s)] - # end Py2K - return view_names + def _get_foreign_table_names(self, connection, schema=None, **kw): + result = connection.execute( + sql.text("SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind = 'f'" + ).columns(relname=sqltypes.Unicode), + schema=schema if schema is not None else self.default_schema_name) + return [name for name, in result] + + @reflection.cache + def get_view_names( + self, connection, schema=None, + include=('plain', 'materialized'), **kw): + + include_kind = {'plain': 'v', 'materialized': 'm'} + try: + kinds = [include_kind[i] for i in util.to_list(include)] + except KeyError: + raise ValueError( + "include %r unknown, needs to be a sequence containing " + "one or both of 'plain' and 'materialized'" % (include,)) + if not kinds: + raise ValueError( + "empty include, needs to be a sequence containing " + "one or both of 'plain' and 'materialized'") + + result = connection.execute( + sql.text("SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind IN (%s)" % + (", ".join("'%s'" % elem for elem in kinds)) + ).columns(relname=sqltypes.Unicode), + schema=schema if schema is not None else self.default_schema_name) + return [name for name, in result] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): - if schema is not None: - current_schema = schema - else: - current_schema = self.default_schema_name - s = """ - SELECT definition FROM pg_views - WHERE schemaname = :schema - AND viewname = :view_name - """ - rp = connection.execute(sql.text(s), - view_name=view_name, schema=current_schema) - if rp: - # Py3K - #view_def = rp.scalar() - # Py2K - view_def = rp.scalar().decode(self.encoding) - # 
end Py2K - return view_def + view_def = connection.scalar( + sql.text("SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relname = :view_name " + "AND c.relkind IN ('v', 'm')" + ).columns(view_def=sqltypes.Unicode), + schema=schema if schema is not None else self.default_schema_name, + view_name=view_name) + return view_def @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -870,8 +2398,10 @@ class PGDialect(default.DefaultDialect): SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), - (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d - WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef) + (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid) + FROM pg_catalog.pg_attrdef d + WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum + AND a.atthasdef) AS DEFAULT, a.attnotnull, a.attnum, a.attrelid as table_oid FROM pg_catalog.pg_attribute a @@ -879,283 +2409,581 @@ class PGDialect(default.DefaultDialect): AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum """ - s = sql.text(SQL_COLS, - bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)], - typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode} - ) + s = sql.text(SQL_COLS, + bindparams=[ + sql.bindparam('table_oid', type_=sqltypes.Integer)], + typemap={ + 'attname': sqltypes.Unicode, + 'default': sqltypes.Unicode} + ) c = connection.execute(s, table_oid=table_oid) rows = c.fetchall() domains = self._load_domains(connection) - enums = self._load_enums(connection) - + enums = dict( + ( + "%s.%s" % (rec['schema'], rec['name']) + if not rec['visible'] else rec['name'], rec) for rec in + self._load_enums(connection, schema='*') + ) + # format columns columns = [] for name, format_type, default, notnull, attnum, table_oid in rows: - ## strip (5) from character varying(5), timestamp(5) with time zone, etc - attype = re.sub(r'\([\d,]+\)', '', format_type) - - # strip '[]' from integer[], etc. - attype = re.sub(r'\[\]', '', attype) - - nullable = not notnull - is_array = format_type.endswith('[]') - charlen = re.search('\(([\d,]+)\)', format_type) - if charlen: - charlen = charlen.group(1) - kwargs = {} - - if attype == 'numeric': - if charlen: - prec, scale = charlen.split(',') - args = (int(prec), int(scale)) - else: - args = () - elif attype == 'double precision': - args = (53, ) - elif attype == 'integer': - args = (32, 0) - elif attype in ('timestamp with time zone', 'time with time zone'): - kwargs['timezone'] = True - if charlen: - kwargs['precision'] = int(charlen) - args = () - elif attype in ('timestamp without time zone', 'time without time zone', 'time'): - kwargs['timezone'] = False - if charlen: - kwargs['precision'] = int(charlen) - args = () - elif attype in ('interval','interval year to month','interval day to second'): - if charlen: - kwargs['precision'] = int(charlen) - args = () - elif charlen: - args = (int(charlen),) - else: - args = () - - if attype in self.ischema_names: - coltype = self.ischema_names[attype] - elif attype in enums: - enum = enums[attype] - coltype = ENUM - if "." in attype: - kwargs['schema'], kwargs['name'] = attype.split('.') - else: - kwargs['name'] = attype - args = tuple(enum['labels']) - elif attype in domains: - domain = domains[attype] - if domain['attype'] in self.ischema_names: - # A table can't override whether the domain is nullable. 
- nullable = domain['nullable'] - if domain['default'] and not default: - # It can, however, override the default value, but can't set it to null. - default = domain['default'] - coltype = self.ischema_names[domain['attype']] - else: - coltype = None - - if coltype: - coltype = coltype(*args, **kwargs) - if is_array: - coltype = ARRAY(coltype) - else: - util.warn("Did not recognize type '%s' of column '%s'" % - (attype, name)) - coltype = sqltypes.NULLTYPE - # adjust the default value - autoincrement = False - if default is not None: - match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) - if match is not None: - autoincrement = True - # the default is related to a Sequence - sch = schema - if '.' not in match.group(2) and sch is not None: - # unconditionally quote the schema name. this could - # later be enhanced to obey quoting rules / "quote schema" - default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3) - - column_info = dict(name=name, type=coltype, nullable=nullable, - default=default, autoincrement=autoincrement) + column_info = self._get_column_info( + name, format_type, default, notnull, domains, enums, schema) columns.append(column_info) return columns - @reflection.cache - def get_primary_keys(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) - PK_SQL = """ - SELECT attname FROM pg_attribute - WHERE attrelid = ( - SELECT indexrelid FROM pg_index i - WHERE i.indrelid = :table_oid - AND i.indisprimary = 't') - ORDER BY attnum - """ - t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) - c = connection.execute(t, table_oid=table_oid) - primary_keys = [r[0] for r in c.fetchall()] - return primary_keys + def _get_column_info(self, name, format_type, default, + notnull, domains, enums, schema): + # strip (*) from character varying(5), timestamp(5) + # with time zone, geometry(POLYGON), etc. + attype = re.sub(r'\(.*\)', '', format_type) + + # strip '[]' from integer[], etc. 
+ attype = attype.replace('[]', '') + + nullable = not notnull + is_array = format_type.endswith('[]') + charlen = re.search(r'\(([\d,]+)\)', format_type) + if charlen: + charlen = charlen.group(1) + args = re.search(r'\((.*)\)', format_type) + if args and args.group(1): + args = tuple(re.split(r'\s*,\s*', args.group(1))) + else: + args = () + kwargs = {} + + if attype == 'numeric': + if charlen: + prec, scale = charlen.split(',') + args = (int(prec), int(scale)) + else: + args = () + elif attype == 'double precision': + args = (53, ) + elif attype == 'integer': + args = () + elif attype in ('timestamp with time zone', + 'time with time zone'): + kwargs['timezone'] = True + if charlen: + kwargs['precision'] = int(charlen) + args = () + elif attype in ('timestamp without time zone', + 'time without time zone', 'time'): + kwargs['timezone'] = False + if charlen: + kwargs['precision'] = int(charlen) + args = () + elif attype == 'bit varying': + kwargs['varying'] = True + if charlen: + args = (int(charlen),) + else: + args = () + elif attype in ('interval', 'interval year to month', + 'interval day to second'): + if charlen: + kwargs['precision'] = int(charlen) + args = () + elif charlen: + args = (int(charlen),) + + while True: + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + break + elif attype in enums: + enum = enums[attype] + coltype = ENUM + kwargs['name'] = enum['name'] + if not enum['visible']: + kwargs['schema'] = enum['schema'] + args = tuple(enum['labels']) + break + elif attype in domains: + domain = domains[attype] + attype = domain['attype'] + # A table can't override whether the domain is nullable. + nullable = domain['nullable'] + if domain['default'] and not default: + # It can, however, override the default + # value, but can't set it to null. + default = domain['default'] + continue + else: + coltype = None + break + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = self.ischema_names['_array'](coltype) + else: + util.warn("Did not recognize type '%s' of column '%s'" % + (attype, name)) + coltype = sqltypes.NULLTYPE + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + if issubclass(coltype._type_affinity, sqltypes.Integer): + autoincrement = True + # the default is related to a Sequence + sch = schema + if '.' not in match.group(2) and sch is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / + # "quote schema" + default = match.group(1) + \ + ('"%s"' % sch) + '.' 
+ \ + match.group(2) + match.group(3) + + column_info = dict(name=name, type=coltype, nullable=nullable, + default=default, autoincrement=autoincrement) + return column_info @reflection.cache - def get_foreign_keys(self, connection, table_name, schema=None, **kw): + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + + if self.server_version_info < (8, 4): + PK_SQL = """ + SELECT a.attname + FROM + pg_class t + join pg_index ix on t.oid = ix.indrelid + join pg_attribute a + on t.oid=a.attrelid AND %s + WHERE + t.oid = :table_oid and ix.indisprimary = 't' + ORDER BY a.attnum + """ % self._pg_index_any("a.attnum", "ix.indkey") + + else: + # unnest() and generate_subscripts() both introduced in + # version 8.4 + PK_SQL = """ + SELECT a.attname + FROM pg_attribute a JOIN ( + SELECT unnest(ix.indkey) attnum, + generate_subscripts(ix.indkey, 1) ord + FROM pg_index ix + WHERE ix.indrelid = :table_oid AND ix.indisprimary + ) k ON a.attnum=k.attnum + WHERE a.attrelid = :table_oid + ORDER BY k.ord + """ + t = sql.text(PK_SQL, typemap={'attname': sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + cols = [r[0] for r in c.fetchall()] + + PK_CONS_SQL = """ + SELECT conname + FROM pg_catalog.pg_constraint r + WHERE r.conrelid = :table_oid AND r.contype = 'p' + ORDER BY 1 + """ + t = sql.text(PK_CONS_SQL, typemap={'conname': sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + name = c.scalar() + + return {'constrained_columns': cols, 'name': name} + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, + postgresql_ignore_search_path=False, **kw): preparer = self.identifier_preparer table_oid = self.get_table_oid(connection, table_name, schema, info_cache=kw.get('info_cache')) + FK_SQL = """ - SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef - FROM pg_catalog.pg_constraint r - WHERE r.conrelid = :table AND r.contype = 'f' + SELECT r.conname, + pg_catalog.pg_get_constraintdef(r.oid, true) as condef, + n.nspname as conschema + FROM pg_catalog.pg_constraint r, + pg_namespace n, + pg_class c + + WHERE r.conrelid = :table AND + r.contype = 'f' AND + c.oid = confrelid AND + n.oid = c.relnamespace ORDER BY 1 """ + # http://www.postgresql.org/docs/9.0/static/sql-createtable.html + FK_REGEX = re.compile( + r'FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)' + r'[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?' + r'[\s]?(ON UPDATE ' + r'(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?' + r'[\s]?(ON DELETE ' + r'(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?' + r'[\s]?(DEFERRABLE|NOT DEFERRABLE)?' + r'[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?' + ) - t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode}) + t = sql.text(FK_SQL, typemap={ + 'conname': sqltypes.Unicode, + 'condef': sqltypes.Unicode}) c = connection.execute(t, table=table_oid) fkeys = [] - for conname, condef in c.fetchall(): - m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups() - (constrained_columns, referred_schema, referred_table, referred_columns) = m - constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] - if referred_schema: - referred_schema = preparer._unquote_identifier(referred_schema) - elif schema is not None and schema == self.default_schema_name: - # no schema (i.e. 
its the default schema), and the table we're - # reflecting has the default schema explicit, then use that. - # i.e. try to use the user's conventions + for conname, condef, conschema in c.fetchall(): + m = re.search(FK_REGEX, condef).groups() + + constrained_columns, referred_schema, \ + referred_table, referred_columns, \ + _, match, _, onupdate, _, ondelete, \ + deferrable, _, initially = m + + if deferrable is not None: + deferrable = True if deferrable == 'DEFERRABLE' else False + constrained_columns = [preparer._unquote_identifier(x) + for x in re.split( + r'\s*,\s*', constrained_columns)] + + if postgresql_ignore_search_path: + # when ignoring search path, we use the actual schema + # provided it isn't the "default" schema + if conschema != self.default_schema_name: + referred_schema = conschema + else: + referred_schema = schema + elif referred_schema: + # referred_schema is the schema that we regexp'ed from + # pg_get_constraintdef(). If the schema is in the search + # path, pg_get_constraintdef() will give us None. + referred_schema = \ + preparer._unquote_identifier(referred_schema) + elif schema is not None and schema == conschema: + # If the actual schema matches the schema of the table + # we're reflecting, then we will use that. referred_schema = schema + referred_table = preparer._unquote_identifier(referred_table) - referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)] + referred_columns = [preparer._unquote_identifier(x) + for x in + re.split(r'\s*,\s', referred_columns)] fkey_d = { - 'name' : conname, - 'constrained_columns' : constrained_columns, - 'referred_schema' : referred_schema, - 'referred_table' : referred_table, - 'referred_columns' : referred_columns + 'name': conname, + 'constrained_columns': constrained_columns, + 'referred_schema': referred_schema, + 'referred_table': referred_table, + 'referred_columns': referred_columns, + 'options': { + 'onupdate': onupdate, + 'ondelete': ondelete, + 'deferrable': deferrable, + 'initially': initially, + 'match': match + } } fkeys.append(fkey_d) return fkeys + def _pg_index_any(self, col, compare_to): + if self.server_version_info < (8, 1): + # http://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us + # "In CVS tip you could replace this with "attnum = ANY (indkey)". + # Unfortunately, most array support doesn't work on int2vector in + # pre-8.1 releases, so I think you're kinda stuck with the above + # for now. 
+ # regards, tom lane" + return "(%s)" % " OR ".join( + "%s[%d] = %s" % (compare_to, ind, col) + for ind in range(0, 10) + ) + else: + return "%s = ANY(%s)" % (col, compare_to) + @reflection.cache def get_indexes(self, connection, table_name, schema, **kw): table_oid = self.get_table_oid(connection, table_name, schema, info_cache=kw.get('info_cache')) - IDX_SQL = """ - SELECT c.relname, i.indisunique, i.indexprs, i.indpred, - a.attname - FROM pg_index i, pg_class c, pg_attribute a - WHERE i.indrelid = :table_oid AND i.indexrelid = c.oid - AND a.attrelid = i.indexrelid AND i.indisprimary = 'f' - ORDER BY c.relname, a.attnum - """ - t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode}) + + # cast indkey as varchar since it's an int2vector, + # returned as a list by some drivers such as pypostgresql + + if self.server_version_info < (8, 5): + IDX_SQL = """ + SELECT + i.relname as relname, + ix.indisunique, ix.indexprs, ix.indpred, + a.attname, a.attnum, NULL, ix.indkey%s, + %s, am.amname + FROM + pg_class t + join pg_index ix on t.oid = ix.indrelid + join pg_class i on i.oid = ix.indexrelid + left outer join + pg_attribute a + on t.oid = a.attrelid and %s + left outer join + pg_am am + on i.relam = am.oid + WHERE + t.relkind IN ('r', 'v', 'f', 'm') + and t.oid = :table_oid + and ix.indisprimary = 'f' + ORDER BY + t.relname, + i.relname + """ % ( + # version 8.3 here was based on observing the + # cast does not work in PG 8.2.4, does work in 8.3.0. + # nothing in PG changelogs regarding this. + "::varchar" if self.server_version_info >= (8, 3) else "", + "i.reloptions" if self.server_version_info >= (8, 2) + else "NULL", + self._pg_index_any("a.attnum", "ix.indkey") + ) + else: + IDX_SQL = """ + SELECT + i.relname as relname, + ix.indisunique, ix.indexprs, ix.indpred, + a.attname, a.attnum, c.conrelid, ix.indkey::varchar, + i.reloptions, am.amname + FROM + pg_class t + join pg_index ix on t.oid = ix.indrelid + join pg_class i on i.oid = ix.indexrelid + left outer join + pg_attribute a + on t.oid = a.attrelid and a.attnum = ANY(ix.indkey) + left outer join + pg_constraint c + on (ix.indrelid = c.conrelid and + ix.indexrelid = c.conindid and + c.contype in ('p', 'u', 'x')) + left outer join + pg_am am + on i.relam = am.oid + WHERE + t.relkind IN ('r', 'v', 'f', 'm') + and t.oid = :table_oid + and ix.indisprimary = 'f' + ORDER BY + t.relname, + i.relname + """ + + t = sql.text(IDX_SQL, typemap={ + 'relname': sqltypes.Unicode, + 'attname': sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) - index_names = {} - indexes = [] + + indexes = defaultdict(lambda: defaultdict(dict)) + sv_idx_name = None for row in c.fetchall(): - idx_name, unique, expr, prd, col = row + (idx_name, unique, expr, prd, col, + col_num, conrelid, idx_key, options, amname) = row + if expr: if idx_name != sv_idx_name: util.warn( - "Skipped unsupported reflection of expression-based index %s" - % idx_name) + "Skipped unsupported reflection of " + "expression-based index %s" + % idx_name) sv_idx_name = idx_name continue + if prd and not idx_name == sv_idx_name: util.warn( - "Predicate of partial index %s ignored during reflection" - % idx_name) + "Predicate of partial index %s ignored during reflection" + % idx_name) sv_idx_name = idx_name - if idx_name in index_names: - index_d = index_names[idx_name] - else: - index_d = {'column_names':[]} - indexes.append(index_d) - index_names[idx_name] = index_d - index_d['name'] = idx_name - index_d['column_names'].append(col) - index_d['unique'] = unique - return 
indexes - def _load_enums(self, connection): + has_idx = idx_name in indexes + index = indexes[idx_name] + if col is not None: + index['cols'][col_num] = col + if not has_idx: + index['key'] = [int(k.strip()) for k in idx_key.split()] + index['unique'] = unique + if conrelid is not None: + index['duplicates_constraint'] = idx_name + if options: + index['options'] = dict( + [option.split("=") for option in options]) + + # it *might* be nice to include that this is 'btree' in the + # reflection info. But we don't want an Index object + # to have a ``postgresql_using`` in it that is just the + # default, so for the moment leaving this out. + if amname and amname != 'btree': + index['amname'] = amname + + result = [] + for name, idx in indexes.items(): + entry = { + 'name': name, + 'unique': idx['unique'], + 'column_names': [idx['cols'][i] for i in idx['key']] + } + if 'duplicates_constraint' in idx: + entry['duplicates_constraint'] = idx['duplicates_constraint'] + if 'options' in idx: + entry.setdefault( + 'dialect_options', {})["postgresql_with"] = idx['options'] + if 'amname' in idx: + entry.setdefault( + 'dialect_options', {})["postgresql_using"] = idx['amname'] + result.append(entry) + return result + + @reflection.cache + def get_unique_constraints(self, connection, table_name, + schema=None, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + + UNIQUE_SQL = """ + SELECT + cons.conname as name, + cons.conkey as key, + a.attnum as col_num, + a.attname as col_name + FROM + pg_catalog.pg_constraint cons + join pg_attribute a + on cons.conrelid = a.attrelid AND + a.attnum = ANY(cons.conkey) + WHERE + cons.conrelid = :table_oid AND + cons.contype = 'u' + """ + + t = sql.text(UNIQUE_SQL, typemap={'col_name': sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + + uniques = defaultdict(lambda: defaultdict(dict)) + for row in c.fetchall(): + uc = uniques[row.name] + uc["key"] = row.key + uc["cols"][row.col_num] = row.col_name + + return [ + {'name': name, + 'column_names': [uc["cols"][i] for i in uc["key"]]} + for name, uc in uniques.items() + ] + + @reflection.cache + def get_check_constraints( + self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + + CHECK_SQL = """ + SELECT + cons.conname as name, + cons.consrc as src + FROM + pg_catalog.pg_constraint cons + WHERE + cons.conrelid = :table_oid AND + cons.contype = 'c' + """ + + c = connection.execute(sql.text(CHECK_SQL), table_oid=table_oid) + + return [ + {'name': name, + 'sqltext': src[1:-1]} + for name, src in c.fetchall() + ] + + def _load_enums(self, connection, schema=None): + schema = schema or self.default_schema_name if not self.supports_native_enum: return {} - ## Load data types for enums: + # Load data types for enums: SQL_ENUMS = """ SELECT t.typname as "name", - -- t.typdefault as "default", -- no enum defaults in 8.4 at least - pg_catalog.pg_type_is_visible(t.oid) as "visible", - n.nspname as "schema", - e.enumlabel as "label" + -- no enum defaults in 8.4 at least + -- t.typdefault as "default", + pg_catalog.pg_type_is_visible(t.oid) as "visible", + n.nspname as "schema", + e.enumlabel as "label" FROM pg_catalog.pg_type t LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace - LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid WHERE t.typtype = 'e' - ORDER BY "name", e.oid -- e.oid gives us 
label order """ - s = sql.text(SQL_ENUMS, typemap={'attname':sqltypes.Unicode, 'label':sqltypes.Unicode}) + if schema != '*': + SQL_ENUMS += "AND n.nspname = :schema " + + # e.oid gives us label order within an enum + SQL_ENUMS += 'ORDER BY "schema", "name", e.oid' + + s = sql.text(SQL_ENUMS, typemap={ + 'attname': sqltypes.Unicode, + 'label': sqltypes.Unicode}) + + if schema != '*': + s = s.bindparams(schema=schema) + c = connection.execute(s) - enums = {} + enums = [] + enum_by_name = {} for enum in c.fetchall(): - if enum['visible']: - # 'visible' just means whether or not the enum is in a - # schema that's on the search path -- or not overriden by - # a schema with higher presedence. If it's not visible, - # it will be prefixed with the schema-name when it's used. - name = enum['name'] + key = (enum['schema'], enum['name']) + if key in enum_by_name: + enum_by_name[key]['labels'].append(enum['label']) else: - name = "%s.%s" % (enum['schema'], enum['name']) - - if name in enums: - enums[name]['labels'].append(enum['label']) - else: - enums[name] = { - 'labels': [enum['label']], - } + enum_by_name[key] = enum_rec = { + 'name': enum['name'], + 'schema': enum['schema'], + 'visible': enum['visible'], + 'labels': [enum['label']], + } + enums.append(enum_rec) return enums def _load_domains(self, connection): - ## Load data types for domains: + # Load data types for domains: SQL_DOMAINS = """ SELECT t.typname as "name", - pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", - not t.typnotnull as "nullable", - t.typdefault as "default", - pg_catalog.pg_type_is_visible(t.oid) as "visible", - n.nspname as "schema" + pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", + not t.typnotnull as "nullable", + t.typdefault as "default", + pg_catalog.pg_type_is_visible(t.oid) as "visible", + n.nspname as "schema" FROM pg_catalog.pg_type t - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace - LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace WHERE t.typtype = 'd' """ - s = sql.text(SQL_DOMAINS, typemap={'attname':sqltypes.Unicode}) + s = sql.text(SQL_DOMAINS, typemap={'attname': sqltypes.Unicode}) c = connection.execute(s) domains = {} for domain in c.fetchall(): - ## strip (30) from character varying(30) - attype = re.search('([^\(]+)', domain['attype']).group(1) + # strip (30) from character varying(30) + attype = re.search(r'([^\(]+)', domain['attype']).group(1) if domain['visible']: # 'visible' just means whether or not the domain is in a - # schema that's on the search path -- or not overriden by - # a schema with higher presedence. If it's not visible, + # schema that's on the search path -- or not overridden by + # a schema with higher precedence. If it's not visible, # it will be prefixed with the schema-name when it's used. name = domain['name'] else: name = "%s.%s" % (domain['schema'], domain['name']) domains[name] = { - 'attype':attype, - 'nullable': domain['nullable'], - 'default': domain['default'] - } + 'attype': attype, + 'nullable': domain['nullable'], + 'default': domain['default'] + } return domains - diff --git a/sqlalchemy/dialects/postgresql/pg8000.py b/sqlalchemy/dialects/postgresql/pg8000.py index a620daa..8c019a2 100644 --- a/sqlalchemy/dialects/postgresql/pg8000.py +++ b/sqlalchemy/dialects/postgresql/pg8000.py @@ -1,63 +1,130 @@ -"""Support for the PostgreSQL database via the pg8000 driver. 
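Taken together, the reflection routines above surface through the public :func:`sqlalchemy.inspect` API. A minimal sketch of reading the reflected information back, assuming a reachable PostgreSQL database (the URL and the table name ``mytable`` here are hypothetical)::

    from sqlalchemy import create_engine, inspect

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    insp = inspect(engine)

    # each column dict carries name/type/nullable/default/autoincrement,
    # as assembled by _get_column_info()
    for col in insp.get_columns("mytable"):
        print(col["name"], col["type"], col["nullable"])

    # get_pk_constraint() returns the constraint name along with its columns
    print(insp.get_pk_constraint("mytable"))

    # foreign keys carry an 'options' dict with onupdate/ondelete,
    # deferrable, initially and match, parsed out of pg_get_constraintdef()
    for fk in insp.get_foreign_keys("mytable"):
        print(fk["referred_table"], fk["options"])

    # unique and check constraints are reflected as well
    print(insp.get_unique_constraints("mytable"))
    print(insp.get_check_constraints("mytable"))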
+# postgresql/pg8000.py
+# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-Connecting
-----------
+"""
+.. dialect:: postgresql+pg8000
+    :name: pg8000
+    :dbapi: pg8000
+    :connectstring: \
+postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
+    :url: https://pythonhosted.org/pg8000/
 
-URLs are of the form
-`postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]`.
+
+.. _pg8000_unicode:
 
 Unicode
 -------
 
-pg8000 requires that the postgresql client encoding be configured in the postgresql.conf file
-in order to use encodings other than ascii. Set this value to the same value as
-the "encoding" parameter on create_engine(), usually "utf-8".
+pg8000 will encode / decode string values between it and the server using the
+PostgreSQL ``client_encoding`` parameter; by default this is the value in
+the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
+Typically, this can be changed to ``utf-8``, as a more useful default::
 
-Interval
---------
+    #client_encoding = sql_ascii # actually, defaults to database
+                                 # encoding
+    client_encoding = utf8
+
+The ``client_encoding`` can be overridden for a session by executing the
+SQL::
+
+    SET CLIENT_ENCODING TO 'utf8';
+
+SQLAlchemy will execute this SQL on all new connections based on the value
+passed to :func:`.create_engine` using the ``client_encoding`` parameter::
+
+    engine = create_engine(
+        "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
+
+
+.. _pg8000_isolation_level:
+
+pg8000 Transaction Isolation Level
+-------------------------------------
+
+The pg8000 dialect offers the same isolation level settings as that
+of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
+
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``AUTOCOMMIT``
+
+.. versionadded:: 0.9.5 support for AUTOCOMMIT isolation level when using
+   pg8000.
+
+.. seealso::
+
+    :ref:`postgresql_isolation_level`
+
+    :ref:`psycopg2_isolation_level`
 
-Passing data from/to the Interval type is not supported as of yet.
 """
+from ... import util, exc
 import decimal
+from ... import processors
+from ...
import types as sqltypes +from .base import ( + PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext, + _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES) +import re +from sqlalchemy.dialects.postgresql.json import JSON -from sqlalchemy.engine import default -from sqlalchemy import util, exc -from sqlalchemy import processors -from sqlalchemy import types as sqltypes -from sqlalchemy.dialects.postgresql.base import PGDialect, \ - PGCompiler, PGIdentifierPreparer, PGExecutionContext class _PGNumeric(sqltypes.Numeric): def result_processor(self, dialect, coltype): if self.asdecimal: - if coltype in (700, 701): - return processors.to_decimal_processor_factory(decimal.Decimal) - elif coltype == 1700: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # pg8000 returns Decimal natively for 1700 return None else: - raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype) else: - if coltype in (700, 701): + if coltype in _FLOAT_TYPES: # pg8000 returns float natively for 701 return None - elif coltype == 1700: + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: return processors.to_float else: - raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype) + + +class _PGNumericNoBind(_PGNumeric): + def bind_processor(self, dialect): + return None + + +class _PGJSON(JSON): + + def result_processor(self, dialect, coltype): + if dialect._dbapi_version > (1, 10, 1): + return None # Has native JSON + else: + return super(_PGJSON, self).result_processor(dialect, coltype) + class PGExecutionContext_pg8000(PGExecutionContext): pass class PGCompiler_pg8000(PGCompiler): - def visit_mod(self, binary, **kw): - return self.process(binary.left) + " %% " + self.process(binary.right) + def visit_mod_binary(self, binary, operator, **kw): + return self.process(binary.left, **kw) + " %% " + \ + self.process(binary.right, **kw) def post_process_text(self, text): if '%%' in text: - util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() " + util.warn("The SQLAlchemy postgresql dialect " + "now automatically escapes '%' in text() " "expressions to '%%'.") return text.replace('%', '%%') @@ -67,30 +134,52 @@ class PGIdentifierPreparer_pg8000(PGIdentifierPreparer): value = value.replace(self.escape_quote, self.escape_to_quote) return value.replace('%', '%%') - + class PGDialect_pg8000(PGDialect): driver = 'pg8000' supports_unicode_statements = True - + supports_unicode_binds = True - + default_paramstyle = 'format' - supports_sane_multi_rowcount = False + supports_sane_multi_rowcount = True execution_ctx_cls = PGExecutionContext_pg8000 statement_compiler = PGCompiler_pg8000 preparer = PGIdentifierPreparer_pg8000 - + description_encoding = 'use_encoding' + colspecs = util.update_copy( PGDialect.colspecs, { - sqltypes.Numeric : _PGNumeric, + sqltypes.Numeric: _PGNumericNoBind, + sqltypes.Float: _PGNumeric, + JSON: _PGJSON, + sqltypes.JSON: _PGJSON } ) - + + def __init__(self, client_encoding=None, **kwargs): + PGDialect.__init__(self, **kwargs) + self.client_encoding = client_encoding + + def initialize(self, connection): + self.supports_sane_multi_rowcount = self._dbapi_version >= (1, 9, 14) + super(PGDialect_pg8000, self).initialize(connection) + + @util.memoized_property 
+ def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, '__version__'): + return tuple( + [ + int(x) for x in re.findall( + r'(\d+)(?:[-\.]?|$)', self.dbapi.__version__)]) + else: + return (99, 99, 99) + @classmethod def dbapi(cls): - return __import__('pg8000').dbapi + return __import__('pg8000') def create_connect_args(self, url): opts = url.translate_connect_args(username='user') @@ -99,7 +188,78 @@ class PGDialect_pg8000(PGDialect): opts.update(url.query) return ([], opts) - def is_disconnect(self, e): + def is_disconnect(self, e, connection, cursor): return "connection is closed" in str(e) + def set_isolation_level(self, connection, level): + level = level.replace('_', ' ') + + # adjust for ConnectionFairy possibly being present + if hasattr(connection, 'connection'): + connection = connection.connection + + if level == 'AUTOCOMMIT': + connection.autocommit = True + elif level in self._isolation_lookup: + connection.autocommit = False + cursor = connection.cursor() + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION " + "ISOLATION LEVEL %s" % level) + cursor.execute("COMMIT") + cursor.close() + else: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s or AUTOCOMMIT" % + (level, self.name, ", ".join(self._isolation_lookup)) + ) + + def set_client_encoding(self, connection, client_encoding): + # adjust for ConnectionFairy possibly being present + if hasattr(connection, 'connection'): + connection = connection.connection + + cursor = connection.cursor() + cursor.execute("SET CLIENT_ENCODING TO '" + client_encoding + "'") + cursor.execute("COMMIT") + cursor.close() + + def do_begin_twophase(self, connection, xid): + connection.connection.tpc_begin((0, xid, '')) + + def do_prepare_twophase(self, connection, xid): + connection.connection.tpc_prepare() + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False): + connection.connection.tpc_rollback((0, xid, '')) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False): + connection.connection.tpc_commit((0, xid, '')) + + def do_recover_twophase(self, connection): + return [row[1] for row in connection.connection.tpc_recover()] + + def on_connect(self): + fns = [] + if self.client_encoding is not None: + def on_connect(conn): + self.set_client_encoding(conn, self.client_encoding) + fns.append(on_connect) + + if self.isolation_level is not None: + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + fns.append(on_connect) + + if len(fns) > 0: + def on_connect(conn): + for fn in fns: + fn(conn) + return on_connect + else: + return None + dialect = PGDialect_pg8000 diff --git a/sqlalchemy/dialects/postgresql/psycopg2.py b/sqlalchemy/dialects/postgresql/psycopg2.py index f21c9a5..5032814 100644 --- a/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/sqlalchemy/dialects/postgresql/psycopg2.py @@ -1,74 +1,335 @@ -"""Support for the PostgreSQL database via the psycopg2 driver. +# postgresql/psycopg2.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php -Driver ------- +""" +.. dialect:: postgresql+psycopg2 + :name: psycopg2 + :dbapi: psycopg2 + :connectstring: postgresql+psycopg2://user:password@host:port/dbname\ +[?key=value&key=value...] 
+ :url: http://pypi.python.org/pypi/psycopg2/ -The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ . -The dialect has several behaviors which are specifically tailored towards compatibility -with this module. +psycopg2 Connect Arguments +----------------------------------- -Note that psycopg1 is **not** supported. +psycopg2-specific keyword arguments which are accepted by +:func:`.create_engine()` are: -Connecting ----------- +* ``server_side_cursors``: Enable the usage of "server side cursors" for SQL + statements which support this feature. What this essentially means from a + psycopg2 point of view is that the cursor is created using a name, e.g. + ``connection.cursor('some name')``, which has the effect that result rows + are not immediately pre-fetched and buffered after statement execution, but + are instead left on the server and only retrieved as needed. SQLAlchemy's + :class:`~sqlalchemy.engine.ResultProxy` uses special row-buffering + behavior when this feature is enabled, such that groups of 100 rows at a + time are fetched over the wire to reduce conversational overhead. + Note that the :paramref:`.Connection.execution_options.stream_results` + execution option is a more targeted + way of enabling this mode on a per-execution basis. +* ``use_native_unicode``: Enable the usage of Psycopg2 "native unicode" mode + per connection. True by default. -URLs are of the form `postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]`. + .. seealso:: -psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are: + :ref:`psycopg2_disable_native_unicode` + +* ``isolation_level``: This option, available for all PostgreSQL dialects, + includes the ``AUTOCOMMIT`` isolation level when using the psycopg2 + dialect. + + .. seealso:: + + :ref:`psycopg2_isolation_level` + +* ``client_encoding``: sets the client encoding in a libpq-agnostic way, + using psycopg2's ``set_client_encoding()`` method. + + .. seealso:: + + :ref:`psycopg2_unicode` + +Unix Domain Connections +------------------------ + +psycopg2 supports connecting via Unix domain connections. When the ``host`` +portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2, +which specifies Unix-domain communication rather than TCP/IP communication:: + + create_engine("postgresql+psycopg2://user:password@/dbname") + +By default, the socket file used is to connect to a Unix-domain socket +in ``/tmp``, or whatever socket directory was specified when PostgreSQL +was built. This value can be overridden by passing a pathname to psycopg2, +using ``host`` as an additional keyword argument:: + + create_engine("postgresql+psycopg2://user:password@/dbname?\ +host=/var/lib/postgresql") + +See also: + +`PQconnectdbParams `_ + +.. _psycopg2_execution_options: + +Per-Statement/Connection Execution Options +------------------------------------------- + +The following DBAPI-specific options are respected when used with +:meth:`.Connection.execution_options`, :meth:`.Executable.execution_options`, +:meth:`.Query.execution_options`, in addition to those not specific to DBAPIs: + +* ``isolation_level`` - Set the transaction isolation level for the lifespan of a + :class:`.Connection` (can only be set on a connection, not a statement + or query). See :ref:`psycopg2_isolation_level`. 
+ +* ``stream_results`` - Enable or disable usage of psycopg2 server side cursors - + this feature makes use of "named" cursors in combination with special + result handling methods so that result rows are not fully buffered. + If ``None`` or not set, the ``server_side_cursors`` option of the + :class:`.Engine` is used. + +* ``max_row_buffer`` - when using ``stream_results``, an integer value that + specifies the maximum number of rows to buffer at a time. This is + interpreted by the :class:`.BufferedRowResultProxy`, and if omitted the + buffer will grow to ultimately store 1000 rows at a time. + + .. versionadded:: 1.0.6 + +.. _psycopg2_unicode: + +Unicode with Psycopg2 +---------------------- + +By default, the psycopg2 driver uses the ``psycopg2.extensions.UNICODE`` +extension, such that the DBAPI receives and returns all strings as Python +Unicode objects directly - SQLAlchemy passes these values through without +change. Psycopg2 here will encode/decode string values based on the +current "client encoding" setting; by default this is the value in +the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``. +Typically, this can be changed to ``utf8``, as a more useful default:: + + # postgresql.conf file + + # client_encoding = sql_ascii # actually, defaults to database + # encoding + client_encoding = utf8 + +A second way to affect the client encoding is to set it within Psycopg2 +locally. SQLAlchemy will call psycopg2's +:meth:`psycopg2:connection.set_client_encoding` method +on all new connections based on the value passed to +:func:`.create_engine` using the ``client_encoding`` parameter:: + + # set_client_encoding() setting; + # works for *all* PostgreSQL versions + engine = create_engine("postgresql://user:pass@host/dbname", + client_encoding='utf8') + +This overrides the encoding specified in the PostgreSQL client configuration. +When using the parameter in this way, the psycopg2 driver emits +``SET client_encoding TO 'utf8'`` on the connection explicitly, and works +in all PostgreSQL versions. + +Note that the ``client_encoding`` setting as passed to :func:`.create_engine` +is **not the same** as the more recently added ``client_encoding`` parameter +now supported by libpq directly. This is enabled when ``client_encoding`` +is passed directly to ``psycopg2.connect()``, and from SQLAlchemy is passed +using the :paramref:`.create_engine.connect_args` parameter:: + + # libpq direct parameter setting; + # only works for PostgreSQL **9.1 and above** + engine = create_engine("postgresql://user:pass@host/dbname", + connect_args={'client_encoding': 'utf8'}) + + # using the query string is equivalent + engine = create_engine("postgresql://user:pass@host/dbname?client_encoding=utf8") + +The above parameter was only added to libpq as of version 9.1 of PostgreSQL, +so using the previous method is better for cross-version support. + +.. _psycopg2_disable_native_unicode: + +Disabling Native Unicode +^^^^^^^^^^^^^^^^^^^^^^^^ + +SQLAlchemy can also be instructed to skip the usage of the psycopg2 +``UNICODE`` extension and to instead utilize its own unicode encode/decode +services, which are normally reserved only for those DBAPIs that don't +fully support unicode directly. Passing ``use_native_unicode=False`` to +:func:`.create_engine` will disable usage of ``psycopg2.extensions.UNICODE``. 
+SQLAlchemy will instead encode data itself into Python bytestrings on the way +in and coerce from bytes on the way back, +using the value of the :func:`.create_engine` ``encoding`` parameter, which +defaults to ``utf-8``. +SQLAlchemy's own unicode encode/decode functionality is steadily becoming +obsolete as most DBAPIs now support unicode fully. + +Bound Parameter Styles +---------------------- + +The default parameter style for the psycopg2 dialect is "pyformat", where +SQL is rendered using ``%(paramname)s`` style. This format has the limitation +that it does not accommodate the unusual case of parameter names that +actually contain percent or parenthesis symbols; as SQLAlchemy in many cases +generates bound parameter names based on the name of a column, the presence +of these characters in a column name can lead to problems. + +There are two solutions to the issue of a :class:`.schema.Column` that contains +one of these characters in its name. One is to specify the +:paramref:`.schema.Column.key` for columns that have such names:: + + measurement = Table('measurement', metadata, + Column('Size (meters)', Integer, key='size_meters') + ) + +Above, an INSERT statement such as ``measurement.insert()`` will use +``size_meters`` as the parameter name, and a SQL expression such as +``measurement.c.size_meters > 10`` will derive the bound parameter name +from the ``size_meters`` key as well. + +.. versionchanged:: 1.0.0 - SQL expressions will use :attr:`.Column.key` + as the source of naming when anonymous bound parameters are created + in SQL expressions; previously, this behavior only applied to + :meth:`.Table.insert` and :meth:`.Table.update` parameter names. + +The other solution is to use a positional format; psycopg2 allows use of the +"format" paramstyle, which can be passed to +:paramref:`.create_engine.paramstyle`:: + + engine = create_engine( + 'postgresql://scott:tiger@localhost:5432/test', paramstyle='format') + +With the above engine, instead of a statement like:: + + INSERT INTO measurement ("Size (meters)") VALUES (%(Size (meters))s) + {'Size (meters)': 1} + +we instead see:: + + INSERT INTO measurement ("Size (meters)") VALUES (%s) + (1, ) + +Where above, the dictionary style is converted into a tuple with positional +style. -* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support - this feature. What this essentially means from a psycopg2 point of view is that the cursor is - created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows - are not immediately pre-fetched and buffered after statement execution, but are instead left - on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy` - uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows - at a time are fetched over the wire to reduce conversational overhead. -* *use_native_unicode* - Enable the usage of Psycopg2 "native unicode" mode per connection. True - by default. -* *isolation_level* - Sets the transaction isolation level for each transaction - within the engine. Valid isolation levels are `READ_COMMITTED`, - `READ_UNCOMMITTED`, `REPEATABLE_READ`, and `SERIALIZABLE`. Transactions ------------ The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. +.. 
_psycopg2_isolation_level:
+
+Psycopg2 Transaction Isolation Level
+-------------------------------------
+
+As discussed in :ref:`postgresql_isolation_level`,
+all PostgreSQL dialects support setting of transaction isolation level
+both via the ``isolation_level`` parameter passed to :func:`.create_engine`,
+as well as the ``isolation_level`` argument used by
+:meth:`.Connection.execution_options`.  When using the psycopg2 dialect, these
+options make use of psycopg2's ``set_isolation_level()`` connection method,
+rather than emitting a PostgreSQL directive; this is because psycopg2's
+API-level setting is always emitted at the start of each transaction in any
+case.
+
+The psycopg2 dialect supports these constants for isolation level:
+
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``AUTOCOMMIT``
+
+.. versionadded:: 0.8.2 support for AUTOCOMMIT isolation level when using
+   psycopg2.
+
+.. seealso::
+
+    :ref:`postgresql_isolation_level`
+
+    :ref:`pg8000_isolation_level`
+
+
 NOTICE logging
 ---------------
 
-The psycopg2 dialect will log Postgresql NOTICE messages via the
+The psycopg2 dialect will log PostgreSQL NOTICE messages via the
 ``sqlalchemy.dialects.postgresql`` logger::
 
     import logging
     logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
 
+.. _psycopg2_hstore:
 
-Per-Statement Execution Options
--------------------------------
+HSTORE type
+------------
 
-The following per-statement execution options are respected:
+The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of
+the HSTORE type.  The SQLAlchemy psycopg2 dialect will enable this extension
+by default when psycopg2 version 2.4 or greater is used, and
+it is detected that the target database has the HSTORE type set up for use.
+In other words, when the dialect makes the first
+connection, a sequence like the following is performed:
 
-* *stream_results* - Enable or disable usage of server side cursors for the SELECT-statement.
-  If *None* or not set, the *server_side_cursors* option of the connection is used.  If
-  auto-commit is enabled, the option is ignored.
+1. Request the available HSTORE oids using
+   ``psycopg2.extras.HstoreAdapter.get_oids()``.
+   If this function returns a list of HSTORE identifiers, we then determine
+   that the ``HSTORE`` extension is present.
+   This function is **skipped** if the version of psycopg2 installed is
+   less than version 2.4.
+
+2. If the ``use_native_hstore`` flag is at its default of ``True``, and
+   we've detected that ``HSTORE`` oids are available, the
+   ``psycopg2.extensions.register_hstore()`` extension is invoked for all
+   connections.
+
+The ``register_hstore()`` extension has the effect of **all Python
+dictionaries being accepted as parameters regardless of the type of target
+column in SQL**.  The dictionaries are converted by this extension into a
+textual HSTORE expression.  If this behavior is not desired, disable the
+use of the hstore extension by setting ``use_native_hstore`` to ``False`` as
+follows::
+
+    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
+                           use_native_hstore=False)
+
+The ``HSTORE`` type is **still supported** when the
+``psycopg2.extensions.register_hstore()`` extension is not used.  It merely
+means that the coercion between Python dictionaries and the HSTORE
+string format, on both the parameter side and the result side, will take
+place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2``
+which may be more performant.
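+
+As a combined sketch of the options discussed above (the connection URL
+and table name are hypothetical), the client encoding, isolation level
+and server side cursor behavior can be configured together::
+
+    engine = create_engine(
+        "postgresql+psycopg2://scott:tiger@localhost/test",
+        client_encoding='utf8',
+        isolation_level='READ COMMITTED')
+
+    with engine.connect() as conn:
+        # stream_results uses a named (server side) cursor for this
+        # statement, so rows are buffered in groups rather than being
+        # fully pre-fetched
+        result = conn.execution_options(stream_results=True).execute(
+            "SELECT * FROM bigtable")
+        for row in result:
+            print(row)
+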
""" +from __future__ import absolute_import -import random import re -import decimal import logging -from sqlalchemy import util -from sqlalchemy import processors -from sqlalchemy.engine import base, default -from sqlalchemy.sql import expression -from sqlalchemy.sql import operators as sql_operators -from sqlalchemy import types as sqltypes -from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, \ - PGIdentifierPreparer, PGExecutionContext, \ - ENUM, ARRAY +from ... import util, exc +import decimal +from ... import processors +from ...engine import result as _result +from ...sql import expression +from ... import types as sqltypes +from .base import PGDialect, PGCompiler, \ + PGIdentifierPreparer, PGExecutionContext, \ + ENUM, _DECIMAL_TYPES, _FLOAT_TYPES,\ + _INT_TYPES, UUID +from .hstore import HSTORE +from .json import JSON, JSONB + +try: + from uuid import UUID as _python_UUID +except ImportError: + _python_UUID = None logger = logging.getLogger('sqlalchemy.dialects.postgresql') @@ -80,82 +341,113 @@ class _PGNumeric(sqltypes.Numeric): def result_processor(self, dialect, coltype): if self.asdecimal: - if coltype in (700, 701): - return processors.to_decimal_processor_factory(decimal.Decimal) - elif coltype == 1700: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, + self._effective_decimal_return_scale) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # pg8000 returns Decimal natively for 1700 return None else: - raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype) else: - if coltype in (700, 701): + if coltype in _FLOAT_TYPES: # pg8000 returns float natively for 701 return None - elif coltype == 1700: + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: return processors.to_float else: - raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype) + class _PGEnum(ENUM): - def __init__(self, *arg, **kw): - super(_PGEnum, self).__init__(*arg, **kw) - if self.convert_unicode: - self.convert_unicode = "force" + def result_processor(self, dialect, coltype): + if self.native_enum and util.py2k and self.convert_unicode is True: + # we can't easily use PG's extensions here because + # the OID is on the fly, and we need to give it a python + # function anyway - not really worth it. + self.convert_unicode = "force_nocheck" + return super(_PGEnum, self).result_processor(dialect, coltype) -class _PGArray(ARRAY): - def __init__(self, *arg, **kw): - super(_PGArray, self).__init__(*arg, **kw) - # FIXME: this check won't work for setups that - # have convert_unicode only on their create_engine(). - if isinstance(self.item_type, sqltypes.String) and \ - self.item_type.convert_unicode: - self.item_type.convert_unicode = "force" -# When we're handed literal SQL, ensure it's a SELECT-query. Since -# 8.3, combining cursors and "FOR UPDATE" has been fine. 
-SERVER_SIDE_CURSOR_RE = re.compile( - r'\s*SELECT', - re.I | re.UNICODE) +class _PGHStore(HSTORE): + def bind_processor(self, dialect): + if dialect._has_native_hstore: + return None + else: + return super(_PGHStore, self).bind_processor(dialect) + + def result_processor(self, dialect, coltype): + if dialect._has_native_hstore: + return None + else: + return super(_PGHStore, self).result_processor(dialect, coltype) + + +class _PGJSON(JSON): + + def result_processor(self, dialect, coltype): + if dialect._has_native_json: + return None + else: + return super(_PGJSON, self).result_processor(dialect, coltype) + + +class _PGJSONB(JSONB): + + def result_processor(self, dialect, coltype): + if dialect._has_native_jsonb: + return None + else: + return super(_PGJSONB, self).result_processor(dialect, coltype) + + +class _PGUUID(UUID): + def bind_processor(self, dialect): + if not self.as_uuid and dialect.use_native_uuid: + nonetype = type(None) + + def process(value): + if value is not None: + value = _python_UUID(value) + return value + return process + + def result_processor(self, dialect, coltype): + if not self.as_uuid and dialect.use_native_uuid: + def process(value): + if value is not None: + value = str(value) + return value + return process + + +_server_side_id = util.counter() + class PGExecutionContext_psycopg2(PGExecutionContext): - def create_cursor(self): - # TODO: coverage for server side cursors + select.for_update() - - if self.dialect.server_side_cursors: - is_server_side = \ - self.execution_options.get('stream_results', True) and ( - (self.compiled and isinstance(self.compiled.statement, expression.Selectable) \ - or \ - ( - (not self.compiled or - isinstance(self.compiled.statement, expression._TextClause)) - and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)) - ) - ) - else: - is_server_side = self.execution_options.get('stream_results', False) - - self.__is_server_side = is_server_side - if is_server_side: - # use server-side cursors: - # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:]) - return self._connection.connection.cursor(ident) - else: - return self._connection.connection.cursor() + def create_server_side_cursor(self): + # use server-side cursors: + # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html + ident = "c_%s_%s" % (hex(id(self))[2:], + hex(_server_side_id())[2:]) + return self._dbapi_connection.cursor(ident) def get_result_proxy(self): + # TODO: ouch if logger.isEnabledFor(logging.INFO): self._log_notices(self.cursor) - - if self.__is_server_side: - return base.BufferedRowResultProxy(self) + + if self._is_server_side: + return _result.BufferedRowResultProxy(self) else: - return base.ResultProxy(self) + return _result.ResultProxy(self) def _log_notices(self, cursor): for notice in cursor.connection.notices: - # NOTICE messages have a + # NOTICE messages have a # newline character at the end logger.info(notice.rstrip()) @@ -163,9 +455,10 @@ class PGExecutionContext_psycopg2(PGExecutionContext): class PGCompiler_psycopg2(PGCompiler): - def visit_mod(self, binary, **kw): - return self.process(binary.left) + " %% " + self.process(binary.right) - + def visit_mod_binary(self, binary, operator, **kw): + return self.process(binary.left, **kw) + " %% " + \ + self.process(binary.right, **kw) + def post_process_text(self, text): return text.replace('%', '%%') @@ -175,47 +468,191 @@ class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): value = 
value.replace(self.escape_quote, self.escape_to_quote) return value.replace('%', '%%') + class PGDialect_psycopg2(PGDialect): driver = 'psycopg2' - supports_unicode_statements = False + if util.py2k: + supports_unicode_statements = False + + supports_server_side_cursors = True + default_paramstyle = 'pyformat' + # set to true based on psycopg2 version supports_sane_multi_rowcount = False execution_ctx_cls = PGExecutionContext_psycopg2 statement_compiler = PGCompiler_psycopg2 preparer = PGIdentifierPreparer_psycopg2 + psycopg2_version = (0, 0) + + FEATURE_VERSION_MAP = dict( + native_json=(2, 5), + native_jsonb=(2, 5, 4), + sane_multi_rowcount=(2, 0, 9), + array_oid=(2, 4, 3), + hstore_adapter=(2, 4) + ) + + _has_native_hstore = False + _has_native_json = False + _has_native_jsonb = False + + engine_config_types = PGDialect.engine_config_types.union([ + ('use_native_unicode', util.asbool), + ]) colspecs = util.update_copy( PGDialect.colspecs, { - sqltypes.Numeric : _PGNumeric, - ENUM : _PGEnum, # needs force_unicode - sqltypes.Enum : _PGEnum, # needs force_unicode - ARRAY : _PGArray, # needs force_unicode + sqltypes.Numeric: _PGNumeric, + ENUM: _PGEnum, # needs force_unicode + sqltypes.Enum: _PGEnum, # needs force_unicode + HSTORE: _PGHStore, + JSON: _PGJSON, + sqltypes.JSON: _PGJSON, + JSONB: _PGJSONB, + UUID: _PGUUID } ) - def __init__(self, server_side_cursors=False, use_native_unicode=True, **kwargs): + def __init__(self, server_side_cursors=False, use_native_unicode=True, + client_encoding=None, + use_native_hstore=True, use_native_uuid=True, + **kwargs): PGDialect.__init__(self, **kwargs) self.server_side_cursors = server_side_cursors self.use_native_unicode = use_native_unicode + self.use_native_hstore = use_native_hstore + self.use_native_uuid = use_native_uuid self.supports_unicode_binds = use_native_unicode - + self.client_encoding = client_encoding + if self.dbapi and hasattr(self.dbapi, '__version__'): + m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', + self.dbapi.__version__) + if m: + self.psycopg2_version = tuple( + int(x) + for x in m.group(1, 2, 3) + if x is not None) + + def initialize(self, connection): + super(PGDialect_psycopg2, self).initialize(connection) + self._has_native_hstore = self.use_native_hstore and \ + self._hstore_oids(connection.connection) \ + is not None + self._has_native_json = \ + self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_json'] + self._has_native_jsonb = \ + self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_jsonb'] + + # http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9 + self.supports_sane_multi_rowcount = \ + self.psycopg2_version >= \ + self.FEATURE_VERSION_MAP['sane_multi_rowcount'] + @classmethod def dbapi(cls): - psycopg = __import__('psycopg2') - return psycopg - + import psycopg2 + return psycopg2 + + @classmethod + def _psycopg2_extensions(cls): + from psycopg2 import extensions + return extensions + + @classmethod + def _psycopg2_extras(cls): + from psycopg2 import extras + return extras + + @util.memoized_property + def _isolation_lookup(self): + extensions = self._psycopg2_extensions() + return { + 'AUTOCOMMIT': extensions.ISOLATION_LEVEL_AUTOCOMMIT, + 'READ COMMITTED': extensions.ISOLATION_LEVEL_READ_COMMITTED, + 'READ UNCOMMITTED': extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, + 'REPEATABLE READ': extensions.ISOLATION_LEVEL_REPEATABLE_READ, + 'SERIALIZABLE': extensions.ISOLATION_LEVEL_SERIALIZABLE + } + + def set_isolation_level(self, connection, level): + try: + level = 
self._isolation_lookup[level.replace('_', ' ')] + except KeyError: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" % + (level, self.name, ", ".join(self._isolation_lookup)) + ) + + connection.set_isolation_level(level) + def on_connect(self): - base_on_connect = super(PGDialect_psycopg2, self).on_connect() + extras = self._psycopg2_extras() + extensions = self._psycopg2_extensions() + + fns = [] + if self.client_encoding is not None: + def on_connect(conn): + conn.set_client_encoding(self.client_encoding) + fns.append(on_connect) + + if self.isolation_level is not None: + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + fns.append(on_connect) + + if self.dbapi and self.use_native_uuid: + def on_connect(conn): + extras.register_uuid(None, conn) + fns.append(on_connect) + if self.dbapi and self.use_native_unicode: - extensions = __import__('psycopg2.extensions').extensions - def connect(conn): + def on_connect(conn): extensions.register_type(extensions.UNICODE, conn) - if base_on_connect: - base_on_connect(conn) - return connect + extensions.register_type(extensions.UNICODEARRAY, conn) + fns.append(on_connect) + + if self.dbapi and self.use_native_hstore: + def on_connect(conn): + hstore_oids = self._hstore_oids(conn) + if hstore_oids is not None: + oid, array_oid = hstore_oids + kw = {'oid': oid} + if util.py2k: + kw['unicode'] = True + if self.psycopg2_version >= \ + self.FEATURE_VERSION_MAP['array_oid']: + kw['array_oid'] = array_oid + extras.register_hstore(conn, **kw) + fns.append(on_connect) + + if self.dbapi and self._json_deserializer: + def on_connect(conn): + if self._has_native_json: + extras.register_default_json( + conn, loads=self._json_deserializer) + if self._has_native_jsonb: + extras.register_default_jsonb( + conn, loads=self._json_deserializer) + fns.append(on_connect) + + if fns: + def on_connect(conn): + for fn in fns: + fn(conn) + return on_connect else: - return base_on_connect + return None + + @util.memoized_instancemethod + def _hstore_oids(self, conn): + if self.psycopg2_version >= self.FEATURE_VERSION_MAP['hstore_adapter']: + extras = self._psycopg2_extras() + oids = extras.HstoreAdapter.get_oids(conn) + if oids is not None and oids[0]: + return oids[0:2] + return None def create_connect_args(self, url): opts = url.translate_connect_args(username='user') @@ -224,16 +661,42 @@ class PGDialect_psycopg2(PGDialect): opts.update(url.query) return ([], opts) - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return 'closed the connection' in str(e) or 'connection not open' in str(e) - elif isinstance(e, self.dbapi.InterfaceError): - return 'connection already closed' in str(e) or 'cursor already closed' in str(e) - elif isinstance(e, self.dbapi.ProgrammingError): - # yes, it really says "losed", not "closed" - return "losed the connection unexpectedly" in str(e) - else: - return False + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + # check the "closed" flag. this might not be + # present on old psycopg2 versions. Also, + # this flag doesn't actually help in a lot of disconnect + # situations, so don't rely on it. + if getattr(connection, 'closed', False): + return True + + # checks based on strings. in the case that .closed + # didn't cut it, fall back onto these. 
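+            # (only the first line of the error is compared, and a
+            # match is disregarded if a double-quote appears anywhere
+            # ahead of it, so that an error message which merely quotes
+            # one of these phrases from user data is not mistaken for a
+            # disconnect)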
+            str_e = str(e).partition("\n")[0]
+            for msg in [
+                # these error messages from libpq: interfaces/libpq/fe-misc.c
+                # and interfaces/libpq/fe-secure.c.
+                'terminating connection',
+                'closed the connection',
+                'connection not open',
+                'could not receive data from server',
+                'could not send data to server',
+                # psycopg2 client errors, psycopg2/connection.h,
+                # psycopg2/cursor.h
+                'connection already closed',
+                'cursor already closed',
+                # not sure where this path is originally from, it may
+                # be obsolete.  It really says "losed", not "closed".
+                'losed the connection unexpectedly',
+                # these can occur in newer SSL
+                'connection has been closed unexpectedly',
+                'SSL SYSCALL error: Bad file descriptor',
+                'SSL SYSCALL error: EOF detected',
+                'SSL error: decryption failed or bad record mac',
+            ]:
+                idx = str_e.find(msg)
+                if idx >= 0 and '"' not in str_e[:idx]:
+                    return True
+        return False
 
 dialect = PGDialect_psycopg2
-
diff --git a/sqlalchemy/dialects/postgresql/pypostgresql.py b/sqlalchemy/dialects/postgresql/pypostgresql.py
index 2e7ea20..ab77493 100644
--- a/sqlalchemy/dialects/postgresql/pypostgresql.py
+++ b/sqlalchemy/dialects/postgresql/pypostgresql.py
@@ -1,18 +1,25 @@
-"""Support for the PostgreSQL database via py-postgresql.
+# postgresql/pypostgresql.py
+# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
+#
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-Connecting
-----------
-
-URLs are of the form `postgresql+pypostgresql://user@password@host:port/dbname[?key=value&key=value...]`.
+"""
+.. dialect:: postgresql+pypostgresql
+    :name: py-postgresql
+    :dbapi: pypostgresql
+    :connectstring: postgresql+pypostgresql://user:password@host:port/dbname\
+[?key=value&key=value...]
+    :url: http://python.projects.pgfoundry.org/
 
 """
-from sqlalchemy.engine import default
-import decimal
-from sqlalchemy import util
-from sqlalchemy import types as sqltypes
-from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
-from sqlalchemy import processors
+from ... import util
+from ... import types as sqltypes
+from .base import PGDialect, PGExecutionContext
+from ...
import processors + class PGNumeric(sqltypes.Numeric): def bind_processor(self, dialect): @@ -24,9 +31,11 @@ class PGNumeric(sqltypes.Numeric): else: return processors.to_float + class PGExecutionContext_pypostgresql(PGExecutionContext): pass + class PGDialect_pypostgresql(PGDialect): driver = 'pypostgresql' @@ -36,7 +45,7 @@ class PGDialect_pypostgresql(PGDialect): default_paramstyle = 'pyformat' # requires trunk version to support sane rowcounts - # TODO: use dbapi version information to set this flag appropariately + # TODO: use dbapi version information to set this flag appropriately supports_sane_rowcount = True supports_sane_multi_rowcount = False @@ -44,8 +53,10 @@ class PGDialect_pypostgresql(PGDialect): colspecs = util.update_copy( PGDialect.colspecs, { - sqltypes.Numeric : PGNumeric, - sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used + sqltypes.Numeric: PGNumeric, + + # prevents PGNumeric from being used + sqltypes.Float: sqltypes.Float, } ) @@ -54,6 +65,23 @@ class PGDialect_pypostgresql(PGDialect): from postgresql.driver import dbapi20 return dbapi20 + _DBAPI_ERROR_NAMES = [ + "Error", + "InterfaceError", "DatabaseError", "DataError", + "OperationalError", "IntegrityError", "InternalError", + "ProgrammingError", "NotSupportedError" + ] + + @util.memoized_property + def dbapi_exception_translation_map(self): + if self.dbapi is None: + return {} + + return dict( + (getattr(self.dbapi, name).__name__, name) + for name in self._DBAPI_ERROR_NAMES + ) + def create_connect_args(self, url): opts = url.translate_connect_args(username='user') if 'port' in opts: @@ -63,7 +91,7 @@ class PGDialect_pypostgresql(PGDialect): opts.update(url.query) return ([], opts) - def is_disconnect(self, e): + def is_disconnect(self, e, connection, cursor): return "connection is closed" in str(e) dialect = PGDialect_pypostgresql diff --git a/sqlalchemy/dialects/postgresql/zxjdbc.py b/sqlalchemy/dialects/postgresql/zxjdbc.py index a886901..f3cfbb8 100644 --- a/sqlalchemy/dialects/postgresql/zxjdbc.py +++ b/sqlalchemy/dialects/postgresql/zxjdbc.py @@ -1,19 +1,46 @@ -"""Support for the PostgreSQL database via the zxjdbc JDBC connector. - -JDBC Driver ------------ - -The official Postgresql JDBC driver is at http://jdbc.postgresql.org/. +# postgresql/zxjdbc.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php """ -from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector -from sqlalchemy.dialects.postgresql.base import PGDialect +.. 
dialect:: postgresql+zxjdbc + :name: zxJDBC for Jython + :dbapi: zxjdbc + :connectstring: postgresql+zxjdbc://scott:tiger@localhost/db + :driverurl: http://jdbc.postgresql.org/ + + +""" +from ...connectors.zxJDBC import ZxJDBCConnector +from .base import PGDialect, PGExecutionContext + + +class PGExecutionContext_zxjdbc(PGExecutionContext): + + def create_cursor(self): + cursor = self._dbapi_connection.cursor() + cursor.datahandler = self.dialect.DataHandler(cursor.datahandler) + return cursor + class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect): jdbc_db_name = 'postgresql' jdbc_driver_name = 'org.postgresql.Driver' + execution_ctx_cls = PGExecutionContext_zxjdbc + + supports_native_decimal = True + + def __init__(self, *args, **kwargs): + super(PGDialect_zxjdbc, self).__init__(*args, **kwargs) + from com.ziclix.python.sql.handler import PostgresqlDataHandler + self.DataHandler = PostgresqlDataHandler + def _get_server_version_info(self, connection): - return tuple(int(x) for x in connection.connection.dbversion.split('.')) + parts = connection.connection.dbversion.split('.') + return tuple(int(x) for x in parts) dialect = PGDialect_zxjdbc diff --git a/sqlalchemy/dialects/type_migration_guidelines.txt b/sqlalchemy/dialects/type_migration_guidelines.txt index c26b65e..e6be205 100644 --- a/sqlalchemy/dialects/type_migration_guidelines.txt +++ b/sqlalchemy/dialects/type_migration_guidelines.txt @@ -5,20 +5,20 @@ Rules for Migrating TypeEngine classes to 0.6 a. Specifying behavior which needs to occur for bind parameters or result row columns. - + b. Specifying types that are entirely specific to the database in use and have no analogue in the sqlalchemy.types package. - + c. Specifying types where there is an analogue in sqlalchemy.types, but the database in use takes vendor-specific flags for those types. d. If a TypeEngine class doesn't provide any of this, it should be *removed* from the dialect. - + 2. the TypeEngine classes are *no longer* used for generating DDL. Dialects now have a TypeCompiler subclass which uses the same visit_XXX model as -other compilers. +other compilers. 3. the "ischema_names" and "colspecs" dictionaries are now required members on the Dialect class. @@ -29,7 +29,7 @@ the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this end users would never need to use _PGNumeric directly. However, if a dialect-specific type is specifying a type *or* arguments that are not present generically, it should match the real name of the type on that backend, in uppercase. E.g. postgresql.INET, -mysql.ENUM, postgresql.ARRAY. +mysql.ENUM, postgresql.ARRAY. Or follow this handy flowchart: @@ -61,8 +61,8 @@ Or follow this handy flowchart: | v the type should - subclass the - UPPERCASE + subclass the + UPPERCASE type in types.py (i.e. class BLOB(types.BLOB)) @@ -85,15 +85,15 @@ Example 4. MySQL has a SET type, there's no analogue for this in types.py. So MySQL names it SET in the dialect's base.py, and it subclasses types.String, since it ultimately deals with strings. -Example 5. Postgresql has a DATETIME type. The DBAPIs handle dates correctly, -and no special arguments are used in PG's DDL beyond what types.py provides. -Postgresql dialect therefore imports types.DATETIME into its base.py. +Example 5. PostgreSQL has a DATETIME type. The DBAPIs handle dates correctly, +and no special arguments are used in PG's DDL beyond what types.py provides. +PostgreSQL dialect therefore imports types.DATETIME into its base.py. 
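[Editor's note: a condensed sketch of the pattern the rules and examples above describe. The class names are invented for this note, modeled loosely on the MySQL SET case in Example 4, and are not part of any shipped dialect. The backend-only type subclasses the closest generic type, and its DDL comes from a visit_XXX method on the dialect's type compiler, never from the type itself::

    from sqlalchemy import types
    from sqlalchemy.sql import compiler

    class SET(types.String):
        # uppercase name, since the type takes vendor-specific arguments
        __visit_name__ = 'SET'

        def __init__(self, *values, **kw):
            self.values = values
            super(SET, self).__init__(**kw)

    class IllustrativeTypeCompiler(compiler.GenericTypeCompiler):
        # the compiler, not the type, renders the DDL string
        def visit_SET(self, type_, **kw):
            return "SET(%s)" % ", ".join("'%s'" % v for v in type_.values)

    # reflection would then point the backend's string name at the type:
    # ischema_names = {'set': SET}
]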
Ideally one should be able to specify a schema using names imported completely from a dialect, all matching the real name on that backend: from sqlalchemy.dialects.postgresql import base as pg - + t = Table('mytable', metadata, Column('id', pg.INTEGER, primary_key=True), Column('name', pg.VARCHAR(300)), @@ -110,36 +110,36 @@ indicate a special type only available in this database, it must be *removed* fr module and from this dictionary. 6. "ischema_names" indicates string descriptions of types as returned from the database -linked to TypeEngine classes. +linked to TypeEngine classes. a. The string name should be matched to the most specific type possible within sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which - case it points to a dialect type. *It doesn't matter* if the dialect has it's + case it points to a dialect type. *It doesn't matter* if the dialect has its own subclass of that type with special bind/result behavior - reflect to the types.py UPPERCASE type as much as possible. With very few exceptions, all types should reflect to an UPPERCASE type. - + b. If the dialect contains a matching dialect-specific type that takes extra arguments which the generic one does not, then point to the dialect-specific type. E.g. mssql.VARCHAR takes a "collation" parameter which should be preserved. - + 5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by a subclass of compiler.GenericTypeCompiler. a. your TypeCompiler class will receive generic and uppercase types from sqlalchemy.types. Do not assume the presence of dialect-specific attributes on these types. - + b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with methods that produce a different DDL name. Uppercase types don't do any kind of "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in all cases, regardless of whether or not that type is legal on the backend database. - + c. the visit_UPPERCASE methods *should* be overridden with methods that add additional - arguments and flags to those types. - + arguments and flags to those types. + d. the visit_lowercase methods are overridden to provide an interpretation of a generic type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)". - + e. visit_lowercase methods should *never* render strings directly - it should always be via calling a visit_UPPERCASE() method. diff --git a/sqlalchemy/engine/__init__.py b/sqlalchemy/engine/__init__.py index 9b3dbed..2a6c68d 100644 --- a/sqlalchemy/engine/__init__.py +++ b/sqlalchemy/engine/__init__.py @@ -1,5 +1,6 @@ # engine/__init__.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -9,7 +10,7 @@ The engine package defines the basic components used to interface DB-API modules with higher-level statement construction, connection-management, execution and result contexts. The primary -"entry point" class into this package is the Engine and it's public +"entry point" class into this package is the Engine and its public constructor ``create_engine()``. This package includes: @@ -50,94 +51,125 @@ url.py within a URL. 
""" -# not sure what this was used for -#import sqlalchemy.databases +from .interfaces import ( + Connectable, + CreateEnginePlugin, + Dialect, + ExecutionContext, + ExceptionContext, -from sqlalchemy.engine.base import ( + # backwards compat + Compiled, + TypeCompiler +) + +from .base import ( + Connection, + Engine, + NestedTransaction, + RootTransaction, + Transaction, + TwoPhaseTransaction, +) + +from .result import ( + BaseRowProxy, BufferedColumnResultProxy, BufferedColumnRow, BufferedRowResultProxy, - Compiled, - Connectable, - Connection, - Dialect, - Engine, - ExecutionContext, - NestedTransaction, + FullyBufferedResultProxy, ResultProxy, - RootTransaction, RowProxy, - Transaction, - TwoPhaseTransaction, - TypeCompiler - ) -from sqlalchemy.engine import strategies -from sqlalchemy import util +) + +from .util import ( + connection_memoize +) -__all__ = ( - 'BufferedColumnResultProxy', - 'BufferedColumnRow', - 'BufferedRowResultProxy', - 'Compiled', - 'Connectable', - 'Connection', - 'Dialect', - 'Engine', - 'ExecutionContext', - 'NestedTransaction', - 'ResultProxy', - 'RootTransaction', - 'RowProxy', - 'Transaction', - 'TwoPhaseTransaction', - 'TypeCompiler', - 'create_engine', - 'engine_from_config', - ) +from . import util, strategies +# backwards compat +from ..sql import ddl default_strategy = 'plain' + + def create_engine(*args, **kwargs): - """Create a new Engine instance. + """Create a new :class:`.Engine` instance. - The standard method of specifying the engine is via URL as the - first positional argument, to indicate the appropriate database - dialect and connection arguments, with additional keyword - arguments sent as options to the dialect and resulting Engine. + The standard calling form is to send the URL as the + first positional argument, usually a string + that indicates database dialect and connection arguments:: - The URL is a string in the form - ``dialect+driver://user:password@host/dbname[?key=value..]``, where - ``dialect`` is a database name such as ``mysql``, ``oracle``, - ``postgresql``, etc., and ``driver`` the name of a DBAPI, such as - ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively, + + engine = create_engine("postgresql://scott:tiger@localhost/test") + + Additional keyword arguments may then follow it which + establish various options on the resulting :class:`.Engine` + and its underlying :class:`.Dialect` and :class:`.Pool` + constructs:: + + engine = create_engine("mysql://scott:tiger@hostname/dbname", + encoding='latin1', echo=True) + + The string form of the URL is + ``dialect[+driver]://user:password@host/dbname[?key=value..]``, where + ``dialect`` is a database name such as ``mysql``, ``oracle``, + ``postgresql``, etc., and ``driver`` the name of a DBAPI, such as + ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively, the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`. - `**kwargs` takes a wide variety of options which are routed - towards their appropriate components. Arguments may be - specific to the Engine, the underlying Dialect, as well as the - Pool. Specific dialects also accept keyword arguments that + ``**kwargs`` takes a wide variety of options which are routed + towards their appropriate components. Arguments may be specific to + the :class:`.Engine`, the underlying :class:`.Dialect`, as well as the + :class:`.Pool`. Specific dialects also accept keyword arguments that are unique to that dialect. Here, we describe the parameters - that are common to most ``create_engine()`` usage. 
+ that are common to most :func:`.create_engine()` usage. - :param assert_unicode: Deprecated. A warning is raised in all cases when a non-Unicode - object is passed when SQLAlchemy would coerce into an encoding - (note: but **not** when the DBAPI handles unicode objects natively). - To suppress or raise this warning to an - error, use the Python warnings filter documented at: - http://docs.python.org/library/warnings.html + Once established, the newly resulting :class:`.Engine` will + request a connection from the underlying :class:`.Pool` once + :meth:`.Engine.connect` is called, or a method which depends on it + such as :meth:`.Engine.execute` is invoked. The :class:`.Pool` in turn + will establish the first actual DBAPI connection when this request + is received. The :func:`.create_engine` call itself does **not** + establish any actual DBAPI connections directly. + + .. seealso:: + + :doc:`/core/engines` + + :doc:`/dialects/index` + + :ref:`connections_toplevel` + + :param case_sensitive=True: if False, result column names + will match in a case-insensitive fashion, that is, + ``row['SomeColumn']``. + + .. versionchanged:: 0.8 + By default, result row names match case-sensitively. + In version 0.7 and prior, all matches were case-insensitive. :param connect_args: a dictionary of options which will be passed directly to the DBAPI's ``connect()`` method as - additional keyword arguments. + additional keyword arguments. See the example + at :ref:`custom_dbapi_args`. - :param convert_unicode=False: if set to True, all - String/character based types will convert Unicode values to raw - byte values going into the database, and all raw byte values to - Python Unicode coming out in result sets. This is an - engine-wide method to provide unicode conversion across the - board. For unicode conversion on a column-by-column level, use - the ``Unicode`` column type instead, described in `types`. + :param convert_unicode=False: if set to True, sets + the default behavior of ``convert_unicode`` on the + :class:`.String` type to ``True``, regardless + of a setting of ``False`` on an individual + :class:`.String` type, thus causing all :class:`.String` + -based columns + to accommodate Python ``unicode`` objects. This flag + is useful as an engine-wide setting when using a + DBAPI that does not natively support Python + ``unicode`` objects and raises an error when + one is received (such as pyodbc with FreeTDS). + + See :class:`.String` for further details on + what this flag indicates. :param creator: a callable which returns a DBAPI connection. This creation function will be passed to the underlying @@ -160,23 +192,105 @@ def create_engine(*args, **kwargs): :ref:`dbengine_logging` for information on how to configure logging directly. - :param encoding='utf-8': the encoding to use for all Unicode - translations, both by engine-wide unicode conversion as well as - the ``Unicode`` type object. + :param encoding: Defaults to ``utf-8``. This is the string + encoding used by SQLAlchemy for string encode/decode + operations which occur within SQLAlchemy, **outside of + the DBAPI.** Most modern DBAPIs feature some degree of + direct support for Python ``unicode`` objects, + what you see in Python 2 as a string of the form + ``u'some string'``. For those scenarios where the + DBAPI is detected as not supporting a Python ``unicode`` + object, this encoding is used to determine the + source/destination encoding. It is **not used** + for those cases where the DBAPI handles unicode + directly. 
+ + To properly configure a system to accommodate Python + ``unicode`` objects, the DBAPI should be + configured to handle unicode to the greatest + degree as is appropriate - see + the notes on unicode pertaining to the specific + target database in use at :ref:`dialect_toplevel`. + + Areas where string encoding may need to be accommodated + outside of the DBAPI include zero or more of: + + * the values passed to bound parameters, corresponding to + the :class:`.Unicode` type or the :class:`.String` type + when ``convert_unicode`` is ``True``; + * the values returned in result set columns corresponding + to the :class:`.Unicode` type or the :class:`.String` + type when ``convert_unicode`` is ``True``; + * the string SQL statement passed to the DBAPI's + ``cursor.execute()`` method; + * the string names of the keys in the bound parameter + dictionary passed to the DBAPI's ``cursor.execute()`` + as well as ``cursor.setinputsizes()`` methods; + * the string column names retrieved from the DBAPI's + ``cursor.description`` attribute. + + When using Python 3, the DBAPI is required to support + *all* of the above values as Python ``unicode`` objects, + which in Python 3 are just known as ``str``. In Python 2, + the DBAPI does not specify unicode behavior at all, + so SQLAlchemy must make decisions for each of the above + values on a per-DBAPI basis - implementations are + completely inconsistent in their behavior. + + :param execution_options: Dictionary execution options which will + be applied to all connections. See + :meth:`~sqlalchemy.engine.Connection.execution_options` + + :param implicit_returning=True: When ``True``, a RETURNING- + compatible construct, if available, will be used to + fetch newly generated primary key values when a single row + INSERT statement is emitted with no existing returning() + clause. This applies to those backends which support RETURNING + or a compatible construct, including PostgreSQL, Firebird, Oracle, + Microsoft SQL Server. Set this to ``False`` to disable + the automatic usage of RETURNING. + + :param isolation_level: this string parameter is interpreted by various + dialects in order to affect the transaction isolation level of the + database connection. The parameter essentially accepts some subset of + these string arguments: ``"SERIALIZABLE"``, ``"REPEATABLE_READ"``, + ``"READ_COMMITTED"``, ``"READ_UNCOMMITTED"`` and ``"AUTOCOMMIT"``. + Behavior here varies per backend, and + individual dialects should be consulted directly. + + Note that the isolation level can also be set on a per-:class:`.Connection` + basis as well, using the + :paramref:`.Connection.execution_options.isolation_level` + feature. + + .. seealso:: + + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`.Connection` isolation level + + :ref:`SQLite Transaction Isolation ` + + :ref:`PostgreSQL Transaction Isolation ` + + :ref:`MySQL Transaction Isolation ` + + :ref:`session_transaction_isolation` - for the ORM :param label_length=None: optional integer value which limits the size of dynamically generated column labels to that many characters. If less than 6, labels are generated as "_(counter)". If ``None``, the value of ``dialect.max_identifier_length`` is used instead. 
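[Editor's note: the parameters documented above are all routed through the single create_engine() call. A minimal sketch follows; the URL and all values are placeholders, and ``connect_timeout`` in particular is a psycopg2-specific DBAPI argument shown only to illustrate ``connect_args``::

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test",
        isolation_level="READ_COMMITTED",      # interpreted by the dialect
        connect_args={"connect_timeout": 10},  # passed to the DBAPI connect()
        label_length=30,                       # cap generated column labels
        echo=True,                             # log SQL statements
    )
]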
-
-    :param listeners: A list of one or more
-      :class:`~sqlalchemy.interfaces.PoolListener` objects which will
+
+    :param listeners: A list of one or more
+        :class:`~sqlalchemy.interfaces.PoolListener` objects which will
       receive connection pool events.
-
+
     :param logging_name:  String identifier which will be used within
       the "name" field of logging records generated within the
-      "sqlalchemy.engine" logger. Defaults to a hexstring of the
+        "sqlalchemy.engine" logger. Defaults to a hexstring of the
       object's id.
 
     :param max_overflow=10: the number of connections to allow in
@@ -184,10 +298,24 @@ def create_engine(*args, **kwargs):
       opened above and beyond the pool_size setting, which defaults
       to five. this is only used with :class:`~sqlalchemy.pool.QueuePool`.
 
-    :param module=None: used by database implementations which
-      support multiple DBAPI modules, this is a reference to a DBAPI2
-      module to be used instead of the engine's default module. For
-      PostgreSQL, the default is psycopg2. For Oracle, it's cx_Oracle.
+    :param module=None: reference to a Python module object (the module
+        itself, not its string name).  Specifies an alternate DBAPI module to
+        be used by the engine's dialect.  Each sub-dialect references a
+        specific DBAPI which will be imported before first connect.  This
+        parameter causes the import to be bypassed, and the given module to
+        be used instead. Can be used for testing of DBAPIs as well as to
+        inject "mock" DBAPI implementations into the :class:`.Engine`.
+
+    :param paramstyle=None: The `paramstyle `_
+        to use when rendering bound parameters.  This style defaults to the
+        one recommended by the DBAPI itself, which is retrieved from the
+        ``.paramstyle`` attribute of the DBAPI.  However, most DBAPIs accept
+        more than one paramstyle, and in particular it may be desirable
+        to change a "named" paramstyle into a "positional" one, or vice versa.
+        When this attribute is passed, it should be one of the values
+        ``"qmark"``, ``"numeric"``, ``"named"``, ``"format"`` or
+        ``"pyformat"``, and should correspond to a parameter style known
+        to be supported by the DBAPI in use.
 
     :param pool=None: an already-constructed instance of
       :class:`~sqlalchemy.pool.Pool`, such as a
@@ -195,7 +323,7 @@ def create_engine(*args, **kwargs):
       pool will be used directly as the underlying connection pool
       for the engine, bypassing whatever connection parameters are
       present in the URL argument. For information on constructing
-      connection pools manually, see `pooling`.
+      connection pools manually, see :ref:`pooling_toplevel`.
 
     :param poolclass=None: a :class:`~sqlalchemy.pool.Pool`
      subclass, which will be used to create a connection pool
@@ -205,70 +333,102 @@ def create_engine(*args, **kwargs):
      of pool to be used.
 
     :param pool_logging_name:  String identifier which will be used within
-      the "name" field of logging records generated within the
-      "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
+        the "name" field of logging records generated within the
+        "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
       id.
 
     :param pool_size=5: the number of connections to keep open
-      inside the connection pool. This used with :class:`~sqlalchemy.pool.QueuePool` as
-      well as :class:`~sqlalchemy.pool.SingletonThreadPool`.
+        inside the connection pool. This is used with
+        :class:`~sqlalchemy.pool.QueuePool` as
+        well as :class:`~sqlalchemy.pool.SingletonThreadPool`. With
+        :class:`~sqlalchemy.pool.QueuePool`, a ``pool_size`` setting
+        of 0 indicates no limit; to disable pooling, set ``poolclass`` to
+        :class:`~sqlalchemy.pool.NullPool` instead.
 
     :param pool_recycle=-1: this setting causes the pool to recycle
       connections after the given number of seconds has passed. It
      defaults to -1, or no timeout. For example, setting to 3600
      means connections will be recycled after one hour. Note that
-      MySQL in particular will ``disconnect automatically`` if no
+      MySQL in particular will disconnect automatically if no
      activity is detected on a connection for eight hours (although
      this is configurable with the MySQLDB connection itself and the
      server configuration as well).
 
+    :param pool_reset_on_return='rollback': set the "reset on return"
+        behavior of the pool, which is whether ``rollback()``,
+        ``commit()``, or nothing is called upon connections
+        being returned to the pool.  See the docstring for
+        ``reset_on_return`` at :class:`.Pool`.
+
+        .. versionadded:: 0.7.6
+
     :param pool_timeout=30: number of seconds to wait before giving
      up on getting a connection from the pool. This is only used
      with :class:`~sqlalchemy.pool.QueuePool`.
 
-    :param strategy='plain': used to invoke alternate :class:`~sqlalchemy.engine.base.Engine.`
-      implementations. Currently available is the ``threadlocal``
-      strategy, which is described in :ref:`threadlocal_strategy`.
-
+    :param strategy='plain': selects alternate engine implementations.
+        Currently available are:
+
+        * the ``threadlocal`` strategy, which is described in
+          :ref:`threadlocal_strategy`;
+        * the ``mock`` strategy, which dispatches all statement
+          execution to a function passed as the argument ``executor``.
+          See `example in the FAQ
+          `_.
+
+    :param executor=None: a function taking arguments
+        ``(sql, *multiparams, **params)``, to which the ``mock`` strategy will
+        dispatch all statement execution. Used only by ``strategy='mock'``.
+
     """
 
     strategy = kwargs.pop('strategy', default_strategy)
     strategy = strategies.strategies[strategy]
     return strategy.create(*args, **kwargs)
 
+
 def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
     """Create a new Engine instance using a configuration dictionary.
 
-    The dictionary is typically produced from a config file where keys
-    are prefixed, such as sqlalchemy.url, sqlalchemy.echo, etc.  The
-    'prefix' argument indicates the prefix to be searched for.
+    The dictionary is typically produced from a config file.
+
+    The keys of interest to ``engine_from_config()`` should be prefixed, e.g.
+    ``sqlalchemy.url``, ``sqlalchemy.echo``, etc.  The 'prefix' argument
+    indicates the prefix to be searched for.  Each matching key (after the
+    prefix is stripped) is treated as though it were the corresponding keyword
+    argument to a :func:`.create_engine` call.
+
+    The only required key is (assuming the default prefix) ``sqlalchemy.url``,
+    which provides the :ref:`database URL `.
 
     A select set of keyword arguments will be "coerced" to their
-    expected type based on string values.  In a future release, this
-    functionality will be expanded and include dialect-specific
-    arguments.
+    expected type based on string values.  The set of arguments
+    is extensible per-dialect using the ``engine_config_types`` accessor.
+
+    :param configuration: A dictionary (typically produced from a config file,
+        but this is not a requirement).  Items whose keys start with the value
+        of 'prefix' will have that prefix stripped, and will then be passed to
+        :ref:`create_engine`.
+
+    :param prefix: Prefix to match and then strip from keys
+        in 'configuration'.
+
+    :param kwargs: Each keyword argument to ``engine_from_config()`` itself
+        overrides the corresponding item taken from the 'configuration'
+        dictionary.  Keyword arguments should *not* be prefixed.
+
     """
 
-    opts = _coerce_config(configuration, prefix)
-    opts.update(kwargs)
-    url = opts.pop('url')
-    return create_engine(url, **opts)
-
-def _coerce_config(configuration, prefix):
-    """Convert configuration values to expected types."""
-    options = dict((key[len(prefix):], configuration[key]) for key in configuration if key.startswith(prefix))
-    for option, type_ in (
-        ('convert_unicode', bool),
-        ('pool_timeout', int),
-        ('echo', bool),
-        ('echo_pool', bool),
-        ('pool_recycle', int),
-        ('pool_size', int),
-        ('max_overflow', int),
-        ('pool_threadlocal', bool),
-    ):
-        util.coerce_kw_type(options, option, type_)
-    return options
+    options = dict((key[len(prefix):], configuration[key])
+                   for key in configuration
+                   if key.startswith(prefix))
+    options['_coerce_config'] = True
+    options.update(kwargs)
+    url = options.pop('url')
+    return create_engine(url, **options)
+
+
+__all__ = (
+    'create_engine',
+    'engine_from_config',
+)
diff --git a/sqlalchemy/engine/base.py b/sqlalchemy/engine/base.py
index dc42ed9..91f4493 100644
--- a/sqlalchemy/engine/base.py
+++ b/sqlalchemy/engine/base.py
@@ -1,931 +1,605 @@
 # engine/base.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
+# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
+#
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
+"""Defines :class:`.Connection` and :class:`.Engine`.
 
-"""Basic components for SQL execution and interfacing with DB-API.
-
-Defines the basic components used to interface DB-API modules with
-higher-level statement-construction, connection-management, execution
-and result contexts.
 """
+from __future__ import with_statement
 
-__all__ = [
-    'BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy',
-    'Compiled', 'Connectable', 'Connection', 'Dialect', 'Engine',
-    'ExecutionContext', 'NestedTransaction', 'ResultProxy', 'RootTransaction',
-    'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction',
-    'connection_memoize']
-
-import inspect, StringIO, sys, operator
-from itertools import izip
-from sqlalchemy import exc, schema, util, types, log
-from sqlalchemy.sql import expression
-
-class Dialect(object):
-    """Define the behavior of a specific database and DB-API combination.
-
-    Any aspect of metadata definition, SQL query generation,
-    execution, result-set handling, or anything else which varies
-    between databases is defined under the general category of the
-    Dialect.  The Dialect acts as a factory for other
-    database-specific object implementations including
-    ExecutionContext, Compiled, DefaultGenerator, and TypeEngine.
-
-    All Dialects implement the following attributes:
-
-    name
-      identifying name for the dialect from a DBAPI-neutral point of view
-      (i.e. 'sqlite')
-
-    driver
-      identifying name for the dialect's DBAPI
-
-    positional
-      True if the paramstyle for this Dialect is positional.
-
-    paramstyle
-      the paramstyle to be used (some DB-APIs support multiple
-      paramstyles).
-
-    convert_unicode
-      True if Unicode conversion should be applied to all ``str``
-      types.
-
-    encoding
-      type of encoding to use for unicode, usually defaults to
-      'utf-8'.
- - statement_compiler - a :class:`~Compiled` class used to compile SQL statements - - ddl_compiler - a :class:`~Compiled` class used to compile DDL statements - - server_version_info - a tuple containing a version number for the DB backend in use. - This value is only available for supporting dialects, and is - typically populated during the initial connection to the database. - - default_schema_name - the name of the default schema. This value is only available for - supporting dialects, and is typically populated during the - initial connection to the database. - - execution_ctx_cls - a :class:`ExecutionContext` class used to handle statement execution - - execute_sequence_format - either the 'tuple' or 'list' type, depending on what cursor.execute() - accepts for the second argument (they vary). - - preparer - a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to - quote identifiers. - - supports_alter - ``True`` if the database supports ``ALTER TABLE``. - - max_identifier_length - The maximum length of identifier names. - - supports_unicode_statements - Indicate whether the DB-API can receive SQL statements as Python - unicode strings - - supports_unicode_binds - Indicate whether the DB-API can receive string bind parameters - as Python unicode strings - - supports_sane_rowcount - Indicate whether the dialect properly implements rowcount for - ``UPDATE`` and ``DELETE`` statements. - - supports_sane_multi_rowcount - Indicate whether the dialect properly implements rowcount for - ``UPDATE`` and ``DELETE`` statements when executed via - executemany. - - preexecute_autoincrement_sequences - True if 'implicit' primary key functions must be executed separately - in order to get their value. This is currently oriented towards - Postgresql. - - implicit_returning - use RETURNING or equivalent during INSERT execution in order to load - newly generated primary keys and other column defaults in one execution, - which are then available via inserted_primary_key. - If an insert statement has returning() specified explicitly, - the "implicit" functionality is not used and inserted_primary_key - will not be available. - - dbapi_type_map - A mapping of DB-API type objects present in this Dialect's - DB-API implementation mapped to TypeEngine implementations used - by the dialect. - - This is used to apply types to result sets based on the DB-API - types present in cursor.description; it only takes effect for - result sets against textual statements where no explicit - typemap was present. - - colspecs - A dictionary of TypeEngine classes from sqlalchemy.types mapped - to subclasses that are specific to the dialect class. This - dictionary is class-level only and is not accessed from the - dialect instance itself. - - supports_default_values - Indicates if the construct ``INSERT INTO tablename DEFAULT - VALUES`` is supported - - supports_sequences - Indicates if the dialect supports CREATE SEQUENCE or similar. - - sequences_optional - If True, indicates if the "optional" flag on the Sequence() construct - should signal to not generate a CREATE SEQUENCE. Applies only to - dialects that support sequences. Currently used only to allow Postgresql - SERIAL to be used on a column that specifies Sequence() for usage on - other backends. - - supports_native_enum - Indicates if the dialect supports a native ENUM construct. - This will prevent types.Enum from generating a CHECK - constraint when that type is used. 
- - supports_native_boolean - Indicates if the dialect supports a native boolean construct. - This will prevent types.Boolean from generating a CHECK - constraint when that type is used. - - """ - - def create_connect_args(self, url): - """Build DB-API compatible connection arguments. - - Given a :class:`~sqlalchemy.engine.url.URL` object, returns a tuple - consisting of a `*args`/`**kwargs` suitable to send directly - to the dbapi's connect function. - - """ - - raise NotImplementedError() - - @classmethod - def type_descriptor(cls, typeobj): - """Transform a generic type to a dialect-specific type. - - Dialect classes will usually use the - :func:`~sqlalchemy.types.adapt_type` function in the types module to - make this job easy. - - The returned result is cached *per dialect class* so can - contain no dialect-instance state. - - """ - - raise NotImplementedError() - - def initialize(self, connection): - """Called during strategized creation of the dialect with a connection. - - Allows dialects to configure options based on server version info or - other properties. - - The connection passed here is a SQLAlchemy Connection object, - with full capabilities. - - The initalize() method of the base dialect should be called via - super(). - - """ - - pass - - def reflecttable(self, connection, table, include_columns=None): - """Load table description from the database. - - Given a :class:`~sqlalchemy.engine.Connection` and a - :class:`~sqlalchemy.schema.Table` object, reflect its columns and - properties from the database. If include_columns (a list or - set) is specified, limit the autoload to the given column - names. - - The default implementation uses the - :class:`~sqlalchemy.engine.reflection.Inspector` interface to - provide the output, building upon the granular table/column/ - constraint etc. methods of :class:`Dialect`. - - """ - - raise NotImplementedError() - - def get_columns(self, connection, table_name, schema=None, **kw): - """Return information about columns in `table_name`. - - Given a :class:`~sqlalchemy.engine.Connection`, a string - `table_name`, and an optional string `schema`, return column - information as a list of dictionaries with these keys: - - name - the column's name - - type - [sqlalchemy.types#TypeEngine] - - nullable - boolean - - default - the column's default value - - autoincrement - boolean - - sequence - a dictionary of the form - {'name' : str, 'start' :int, 'increment': int} - - Additional column attributes may be present. - """ - - raise NotImplementedError() - - def get_primary_keys(self, connection, table_name, schema=None, **kw): - """Return information about primary keys in `table_name`. - - Given a :class:`~sqlalchemy.engine.Connection`, a string - `table_name`, and an optional string `schema`, return primary - key information as a list of column names. - """ - - raise NotImplementedError() - - def get_foreign_keys(self, connection, table_name, schema=None, **kw): - """Return information about foreign_keys in `table_name`. 
- - Given a :class:`~sqlalchemy.engine.Connection`, a string - `table_name`, and an optional string `schema`, return foreign - key information as a list of dicts with these keys: - - name - the constraint's name - - constrained_columns - a list of column names that make up the foreign key - - referred_schema - the name of the referred schema - - referred_table - the name of the referred table - - referred_columns - a list of column names in the referred table that correspond to - constrained_columns - """ - - raise NotImplementedError() - - def get_table_names(self, connection, schema=None, **kw): - """Return a list of table names for `schema`.""" - - raise NotImplementedError - - def get_view_names(self, connection, schema=None, **kw): - """Return a list of all view names available in the database. - - schema: - Optional, retrieve names from a non-default schema. - """ - - raise NotImplementedError() - - def get_view_definition(self, connection, view_name, schema=None, **kw): - """Return view definition. - - Given a :class:`~sqlalchemy.engine.Connection`, a string - `view_name`, and an optional string `schema`, return the view - definition. - """ - - raise NotImplementedError() - - def get_indexes(self, connection, table_name, schema=None, **kw): - """Return information about indexes in `table_name`. - - Given a :class:`~sqlalchemy.engine.Connection`, a string - `table_name` and an optional string `schema`, return index - information as a list of dictionaries with these keys: - - name - the index's name - - column_names - list of column names in order - - unique - boolean - """ - - raise NotImplementedError() - - def normalize_name(self, name): - """convert the given name to lowercase if it is detected as case insensitive. - - this method is only used if the dialect defines requires_name_normalize=True. - - """ - raise NotImplementedError() - - def denormalize_name(self, name): - """convert the given name to a case insensitive identifier for the backend - if it is an all-lowercase name. - - this method is only used if the dialect defines requires_name_normalize=True. - - """ - raise NotImplementedError() - - def has_table(self, connection, table_name, schema=None): - """Check the existence of a particular table in the database. - - Given a :class:`~sqlalchemy.engine.Connection` object and a string - `table_name`, return True if the given table (possibly within - the specified `schema`) exists in the database, False - otherwise. - """ - - raise NotImplementedError() - - def has_sequence(self, connection, sequence_name, schema=None): - """Check the existence of a particular sequence in the database. - - Given a :class:`~sqlalchemy.engine.Connection` object and a string - `sequence_name`, return True if the given sequence exists in - the database, False otherwise. - """ - - raise NotImplementedError() - - def _get_server_version_info(self, connection): - """Retrieve the server version info from the given connection. - - This is used by the default implementation to populate the - "server_version_info" attribute and is called exactly - once upon first connect. - - """ - - raise NotImplementedError() - - def _get_default_schema_name(self, connection): - """Return the string name of the currently selected schema from the given connection. - - This is used by the default implementation to populate the - "default_schema_name" attribute and is called exactly - once upon first connect. 
- - """ - - raise NotImplementedError() - - def do_begin(self, connection): - """Provide an implementation of *connection.begin()*, given a DB-API connection.""" - - raise NotImplementedError() - - def do_rollback(self, connection): - """Provide an implementation of *connection.rollback()*, given a DB-API connection.""" - - raise NotImplementedError() - - def create_xid(self): - """Create a two-phase transaction ID. - - This id will be passed to do_begin_twophase(), - do_rollback_twophase(), do_commit_twophase(). Its format is - unspecified. - """ - - raise NotImplementedError() - - def do_commit(self, connection): - """Provide an implementation of *connection.commit()*, given a DB-API connection.""" - - raise NotImplementedError() - - def do_savepoint(self, connection, name): - """Create a savepoint with the given name on a SQLAlchemy connection.""" - - raise NotImplementedError() - - def do_rollback_to_savepoint(self, connection, name): - """Rollback a SQL Alchemy connection to the named savepoint.""" - - raise NotImplementedError() - - def do_release_savepoint(self, connection, name): - """Release the named savepoint on a SQL Alchemy connection.""" - - raise NotImplementedError() - - def do_begin_twophase(self, connection, xid): - """Begin a two phase transaction on the given connection.""" - - raise NotImplementedError() - - def do_prepare_twophase(self, connection, xid): - """Prepare a two phase transaction on the given connection.""" - - raise NotImplementedError() - - def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): - """Rollback a two phase transaction on the given connection.""" - - raise NotImplementedError() - - def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): - """Commit a two phase transaction on the given connection.""" - - raise NotImplementedError() - - def do_recover_twophase(self, connection): - """Recover list of uncommited prepared two phase transaction identifiers on the given connection.""" - - raise NotImplementedError() - - def do_executemany(self, cursor, statement, parameters, context=None): - """Provide an implementation of *cursor.executemany(statement, parameters)*.""" - - raise NotImplementedError() - - def do_execute(self, cursor, statement, parameters, context=None): - """Provide an implementation of *cursor.execute(statement, parameters)*.""" - - raise NotImplementedError() - - def is_disconnect(self, e): - """Return True if the given DB-API error indicates an invalid connection""" - - raise NotImplementedError() - - def on_connect(self): - """return a callable which sets up a newly created DBAPI connection. - - The callable accepts a single argument "conn" which is the - DBAPI connection itself. It has no return value. - - This is used to set dialect-wide per-connection options such as isolation - modes, unicode modes, etc. - - If a callable is returned, it will be assembled into a pool listener - that receives the direct DBAPI connection, with all wrappers removed. - - If None is returned, no listener will be generated. - - """ - return None - - -class ExecutionContext(object): - """A messenger object for a Dialect that corresponds to a single execution. - - ExecutionContext should have these data members: - - connection - Connection object which can be freely used by default value - generators to execute SQL. This Connection should reference the - same underlying connection/transactional resources of - root_connection. 
- - root_connection - Connection object which is the source of this ExecutionContext. This - Connection may have close_with_result=True set, in which case it can - only be used once. - - dialect - dialect which created this ExecutionContext. - - cursor - DB-API cursor procured from the connection, - - compiled - if passed to constructor, sqlalchemy.engine.base.Compiled object - being executed, - - statement - string version of the statement to be executed. Is either - passed to the constructor, or must be created from the - sql.Compiled object by the time pre_exec() has completed. - - parameters - bind parameters passed to the execute() method. For compiled - statements, this is a dictionary or list of dictionaries. For - textual statements, it should be in a format suitable for the - dialect's paramstyle (i.e. dict or list of dicts for non - positional, list or list of lists/tuples for positional). - - isinsert - True if the statement is an INSERT. - - isupdate - True if the statement is an UPDATE. - - should_autocommit - True if the statement is a "committable" statement. - - postfetch_cols - a list of Column objects for which a server-side default or - inline SQL expression value was fired off. Applies to inserts - and updates. - """ - - def create_cursor(self): - """Return a new cursor generated from this ExecutionContext's connection. - - Some dialects may wish to change the behavior of - connection.cursor(), such as postgresql which may return a PG - "server side" cursor. - """ - - raise NotImplementedError() - - def pre_exec(self): - """Called before an execution of a compiled statement. - - If a compiled statement was passed to this ExecutionContext, - the `statement` and `parameters` datamembers must be - initialized after this statement is complete. - """ - - raise NotImplementedError() - - def post_exec(self): - """Called after the execution of a compiled statement. - - If a compiled statement was passed to this ExecutionContext, - the `last_insert_ids`, `last_inserted_params`, etc. - datamembers should be available after this method completes. - """ - - raise NotImplementedError() - - def result(self): - """Return a result object corresponding to this ExecutionContext. - - Returns a ResultProxy. - """ - - raise NotImplementedError() - - def handle_dbapi_exception(self, e): - """Receive a DBAPI exception which occured upon execute, result fetch, etc.""" - - raise NotImplementedError() - - def should_autocommit_text(self, statement): - """Parse the given textual statement and return True if it refers to a "committable" statement""" - - raise NotImplementedError() - - def last_inserted_params(self): - """Return a dictionary of the full parameter dictionary for the last compiled INSERT statement. - - Includes any ColumnDefaults or Sequences that were pre-executed. - """ - - raise NotImplementedError() - - def last_updated_params(self): - """Return a dictionary of the full parameter dictionary for the last compiled UPDATE statement. - - Includes any ColumnDefaults that were pre-executed. - """ - - raise NotImplementedError() - - def lastrow_has_defaults(self): - """Return True if the last INSERT or UPDATE row contained - inlined or database-side defaults. - """ - - raise NotImplementedError() - - def get_rowcount(self): - """Return the number of rows produced (by a SELECT query) - or affected (by an INSERT/UPDATE/DELETE statement). 
- - Note that this row count may not be properly implemented - in some dialects; this is indicated by the - ``supports_sane_rowcount`` and ``supports_sane_multi_rowcount`` - dialect attributes. - - """ - - raise NotImplementedError() - - -class Compiled(object): - """Represent a compiled SQL or DDL expression. - - The ``__str__`` method of the ``Compiled`` object should produce - the actual text of the statement. ``Compiled`` objects are - specific to their underlying database dialect, and also may - or may not be specific to the columns referenced within a - particular set of bind parameters. In no case should the - ``Compiled`` object be dependent on the actual values of those - bind parameters, even though it may reference those values as - defaults. - """ - - def __init__(self, dialect, statement, bind=None): - """Construct a new ``Compiled`` object. - - :param dialect: ``Dialect`` to compile against. - - :param statement: ``ClauseElement`` to be compiled. - - :param bind: Optional Engine or Connection to compile this statement against. - """ - - self.dialect = dialect - self.statement = statement - self.bind = bind - self.can_execute = statement.supports_execution - - def compile(self): - """Produce the internal string representation of this element.""" - - self.string = self.process(self.statement) - - @property - def sql_compiler(self): - """Return a Compiled that is capable of processing SQL expressions. - - If this compiler is one, it would likely just return 'self'. - - """ - - raise NotImplementedError() - - def process(self, obj, **kwargs): - return obj._compiler_dispatch(self, **kwargs) - - def __str__(self): - """Return the string text of the generated SQL or DDL.""" - - return self.string or '' - - def construct_params(self, params=None): - """Return the bind params for this compiled object. - - :param params: a dict of string/object pairs whos values will - override bind values compiled in to the - statement. - """ - - raise NotImplementedError() - - @property - def params(self): - """Return the bind params for this compiled object.""" - return self.construct_params() - - def execute(self, *multiparams, **params): - """Execute this compiled object.""" - - e = self.bind - if e is None: - raise exc.UnboundExecutionError("This Compiled object is not bound to any Engine or Connection.") - return e._execute_compiled(self, multiparams, params) - - def scalar(self, *multiparams, **params): - """Execute this compiled object and return the result's scalar value.""" - - return self.execute(*multiparams, **params).scalar() - - -class TypeCompiler(object): - """Produces DDL specification for TypeEngine objects.""" - - def __init__(self, dialect): - self.dialect = dialect - - def process(self, type_): - return type_._compiler_dispatch(self) - - -class Connectable(object): - """Interface for an object which supports execution of SQL constructs. - - The two implementations of ``Connectable`` are :class:`Connection` and - :class:`Engine`. - - Connectable must also implement the 'dialect' member which references a - :class:`Dialect` instance. 
- """ - - def contextual_connect(self): - """Return a Connection object which may be part of an ongoing context.""" - - raise NotImplementedError() - - def create(self, entity, **kwargs): - """Create a table or index given an appropriate schema object.""" - - raise NotImplementedError() - - def drop(self, entity, **kwargs): - """Drop a table or index given an appropriate schema object.""" - - raise NotImplementedError() - - def execute(self, object, *multiparams, **params): - raise NotImplementedError() - - def _execute_clauseelement(self, elem, multiparams=None, params=None): - raise NotImplementedError() +import sys +from .. import exc, util, log, interfaces +from ..sql import util as sql_util +from ..sql import schema +from .interfaces import Connectable, ExceptionContext +from .util import _distill_params +import contextlib class Connection(Connectable): """Provides high-level functionality for a wrapped DB-API connection. - Provides execution support for string-based SQL statements as well - as ClauseElement, Compiled and DefaultGenerator objects. Provides - a begin method to return Transaction objects. + Provides execution support for string-based SQL statements as well as + :class:`.ClauseElement`, :class:`.Compiled` and :class:`.DefaultGenerator` + objects. Provides a :meth:`begin` method to return :class:`.Transaction` + objects. - The Connection object is **not** thread-safe. + The Connection object is **not** thread-safe. While a Connection can be + shared among threads using properly synchronized access, it is still + possible that the underlying DBAPI connection may not support shared + access between threads. Check the DBAPI documentation for details. + + The Connection object represents a single dbapi connection checked out + from the connection pool. In this state, the connection pool has no affect + upon the connection, including its expiration or timeout state. For the + connection pool to properly manage connections, connections should be + returned to the connection pool (i.e. ``connection.close()``) whenever the + connection is not in use. .. index:: single: thread safety; Connection + """ - _execution_options = util.frozendict() - + + schema_for_object = schema._schema_getter(None) + """Return the ".schema" attribute for an object. + + Used for :class:`.Table`, :class:`.Sequence` and similar objects, + and takes into account + the :paramref:`.Connection.execution_options.schema_translate_map` + parameter. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`schema_translating` + + """ + def __init__(self, engine, connection=None, close_with_result=False, - _branch=False, _execution_options=None): + _branch_from=None, _execution_options=None, + _dispatch=None, + _has_events=None): """Construct a new Connection. - Connection objects are typically constructed by an - :class:`~sqlalchemy.engine.Engine`, see the ``connect()`` and - ``contextual_connect()`` methods of Engine. + The constructor here is not public and is only called only by an + :class:`.Engine`. See :meth:`.Engine.connect` and + :meth:`.Engine.contextual_connect` methods. 
+ """ self.engine = engine - self.__connection = connection or engine.raw_connection() - self.__transaction = None - self.should_close_with_result = close_with_result - self.__savepoint_seq = 0 - self.__branch = _branch - self.__invalid = False - self._echo = self.engine._should_log_info() - if _execution_options: - self._execution_options = self._execution_options.union(_execution_options) + self.dialect = engine.dialect + self.__branch_from = _branch_from + self.__branch = _branch_from is not None + + if _branch_from: + self.__connection = connection + self._execution_options = _execution_options + self._echo = _branch_from._echo + self.should_close_with_result = False + self.dispatch = _dispatch + self._has_events = _branch_from._has_events + self.schema_for_object = _branch_from.schema_for_object + else: + self.__connection = connection \ + if connection is not None else engine.raw_connection() + self.__transaction = None + self.__savepoint_seq = 0 + self.should_close_with_result = close_with_result + self.__invalid = False + self.__can_reconnect = True + self._echo = self.engine._should_log_info() + + if _has_events is None: + # if _has_events is sent explicitly as False, + # then don't join the dispatch of the engine; we don't + # want to handle any of the engine's events in that case. + self.dispatch = self.dispatch._join(engine.dispatch) + self._has_events = _has_events or ( + _has_events is None and engine._has_events) + + assert not _execution_options + self._execution_options = engine._execution_options + + if self._has_events or self.engine._has_events: + self.dispatch.engine_connect(self, self.__branch) def _branch(self): """Return a new Connection which references this Connection's engine and connection; but does not have close_with_result enabled, and also whose close() method does nothing. - This is used to execute "sub" statements within a single execution, - usually an INSERT statement. - """ + The Core uses this very sparingly, only in the case of + custom SQL default functions that are to be INSERTed as the + primary key of a row where we need to get the value back, so we have + to invoke it distinctly - this is a very uncommon case. - return self.engine.Connection(self.engine, self.__connection, _branch=True) - - def execution_options(self, **opt): - """ Set non-SQL options for the connection which take effect during execution. - - The method returns a copy of this :class:`Connection` which references - the same underlying DBAPI connection, but also defines the given execution - options which will take effect for a call to :meth:`execute`. As the new - :class:`Connection` references the same underlying resource, it is probably - best to ensure that the copies would be discarded immediately, which - is implicit if used as in:: - - result = connection.execution_options(stream_results=True).execute(stmt) - - The options are the same as those accepted by - :meth:`sqlalchemy.sql.expression.Executable.execution_options`. + Userland code accesses _branch() when the connect() or + contextual_connect() methods are called. The branched connection + acts as much as possible like the parent, except that it stays + connected when a close() event occurs. 
""" - return self.engine.Connection( - self.engine, self.__connection, - _branch=self.__branch, _execution_options=opt) - + if self.__branch_from: + return self.__branch_from._branch() + else: + return self.engine._connection_cls( + self.engine, + self.__connection, + _branch_from=self, + _execution_options=self._execution_options, + _has_events=self._has_events, + _dispatch=self.dispatch) + @property - def dialect(self): - "Dialect used by this Connection." + def _root(self): + """return the 'root' connection. - return self.engine.dialect + Returns 'self' if this connection is not a branch, else + returns the root connection from which we ultimately branched. + + """ + + if self.__branch_from: + return self.__branch_from + else: + return self + + def _clone(self): + """Create a shallow copy of this Connection. + + """ + c = self.__class__.__new__(self.__class__) + c.__dict__ = self.__dict__.copy() + return c + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + def execution_options(self, **opt): + r""" Set non-SQL options for the connection which take effect + during execution. + + The method returns a copy of this :class:`.Connection` which references + the same underlying DBAPI connection, but also defines the given + execution options which will take effect for a call to + :meth:`execute`. As the new :class:`.Connection` references the same + underlying resource, it's usually a good idea to ensure that the copies + will be discarded immediately, which is implicit if used as in:: + + result = connection.execution_options(stream_results=True).\ + execute(stmt) + + Note that any key/value can be passed to + :meth:`.Connection.execution_options`, and it will be stored in the + ``_execution_options`` dictionary of the :class:`.Connection`. It + is suitable for usage by end-user schemes to communicate with + event listeners, for example. + + The keywords that are currently recognized by SQLAlchemy itself + include all those listed under :meth:`.Executable.execution_options`, + as well as others that are specific to :class:`.Connection`. + + :param autocommit: Available on: Connection, statement. + When True, a COMMIT will be invoked after execution + when executed in 'autocommit' mode, i.e. when an explicit + transaction is not begun on the connection. Note that DBAPI + connections by default are always in a transaction - SQLAlchemy uses + rules applied to different kinds of statements to determine if + COMMIT will be invoked in order to provide its "autocommit" feature. + Typically, all INSERT/UPDATE/DELETE statements as well as + CREATE/DROP statements have autocommit behavior enabled; SELECT + constructs do not. Use this option when invoking a SELECT or other + specific SQL construct where COMMIT is desired (typically when + calling stored procedures and such), and an explicit + transaction is not in progress. + + :param compiled_cache: Available on: Connection. + A dictionary where :class:`.Compiled` objects + will be cached when the :class:`.Connection` compiles a clause + expression into a :class:`.Compiled` object. + It is the user's responsibility to + manage the size of this dictionary, which will have keys + corresponding to the dialect, clause element, the column + names within the VALUES or SET clause of an INSERT or UPDATE, + as well as the "batch" mode for an INSERT or UPDATE statement. + The format of this dictionary is not guaranteed to stay the + same in future releases. 
+ + Note that the ORM makes use of its own "compiled" caches for + some operations, including flush operations. The caching + used by the ORM internally supersedes a cache dictionary + specified here. + + :param isolation_level: Available on: :class:`.Connection`. + Set the transaction isolation level for + the lifespan of this :class:`.Connection` object (*not* the + underlying DBAPI connection, for which the level is reset + to its original setting upon termination of this + :class:`.Connection` object). + + Valid values include + those string values accepted by the + :paramref:`.create_engine.isolation_level` + parameter passed to :func:`.create_engine`. These levels are + semi-database specific; see individual dialect documentation for + valid levels. + + Note that this option necessarily affects the underlying + DBAPI connection for the lifespan of the originating + :class:`.Connection`, and is not per-execution. This + setting is not removed until the underlying DBAPI connection + is returned to the connection pool, i.e. + the :meth:`.Connection.close` method is called. + + .. warning:: The ``isolation_level`` execution option should + **not** be used when a transaction is already established, that + is, the :meth:`.Connection.begin` method or similar has been + called. A database cannot change the isolation level on a + transaction in progress, and different DBAPIs and/or + SQLAlchemy dialects may implicitly roll back or commit + the transaction, or not affect the connection at all. + + .. versionchanged:: 0.9.9 A warning is emitted when the + ``isolation_level`` execution option is used after a + transaction has been started with :meth:`.Connection.begin` + or similar. + + .. note:: The ``isolation_level`` execution option is implicitly + reset if the :class:`.Connection` is invalidated, e.g. via + the :meth:`.Connection.invalidate` method, or if a + disconnection error occurs. The new connection produced after + the invalidation will not have the isolation level re-applied + to it automatically. + + .. seealso:: + + :paramref:`.create_engine.isolation_level` + - set per :class:`.Engine` isolation level + + :meth:`.Connection.get_isolation_level` - view current level + + :ref:`SQLite Transaction Isolation ` + + :ref:`PostgreSQL Transaction Isolation ` + + :ref:`MySQL Transaction Isolation ` + + :ref:`SQL Server Transaction Isolation ` + + :ref:`session_transaction_isolation` - for the ORM + + :param no_parameters: When ``True``, if the final parameter + list or dictionary is totally empty, will invoke the + statement on the cursor as ``cursor.execute(statement)``, + not passing the parameter collection at all. + Some DBAPIs such as psycopg2 and mysql-python consider + percent signs as significant only when parameters are + present; this option allows code to generate SQL + containing percent signs (and possibly other characters) + that is neutral regarding whether it's executed by the DBAPI + or piped into a script that's later invoked by + command line tools. + + .. versionadded:: 0.7.6 + + :param stream_results: Available on: Connection, statement. + Indicate to the dialect that results should be + "streamed" and not pre-buffered, if possible. This is a limitation + of many DBAPIs. The flag is currently understood only by the + psycopg2, mysqldb and pymysql dialects. + + :param schema_translate_map: Available on: Connection, Engine. 
+ A dictionary mapping schema names to schema names, that will be + applied to the :paramref:`.Table.schema` element of each + :class:`.Table` encountered when SQL or DDL expression elements + are compiled into strings; the resulting schema name will be + converted based on presence in the map of the original name. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`schema_translating` + + """ + c = self._clone() + c._execution_options = c._execution_options.union(opt) + if self._has_events or self.engine._has_events: + self.dispatch.set_connection_execution_options(c, opt) + self.dialect.set_connection_execution_options(c, opt) + return c @property def closed(self): """Return True if this connection is closed.""" - return not self.__invalid and '_Connection__connection' not in self.__dict__ + return '_Connection__connection' not in self.__dict__ \ + and not self.__can_reconnect @property def invalidated(self): """Return True if this connection was invalidated.""" - return self.__invalid + return self._root.__invalid @property def connection(self): - "The underlying DB-API connection managed by this Connection." + """The underlying DB-API connection managed by this Connection. + + .. seealso:: + + + :ref:`dbapi_connections` + + """ try: return self.__connection except AttributeError: - if self.__invalid: - if self.__transaction is not None: - raise exc.InvalidRequestError("Can't reconnect until invalid transaction is rolled back") - self.__connection = self.engine.raw_connection() - self.__invalid = False - return self.__connection - raise exc.InvalidRequestError("This Connection is closed") + try: + return self._revalidate_connection() + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def get_isolation_level(self): + """Return the current isolation level assigned to this + :class:`.Connection`. + + This will typically be the default isolation level as determined + by the dialect, unless if the + :paramref:`.Connection.execution_options.isolation_level` + feature has been used to alter the isolation level on a + per-:class:`.Connection` basis. + + This attribute will typically perform a live SQL operation in order + to procure the current isolation level, so the value returned is the + actual level on the underlying DBAPI connection regardless of how + this state was set. Compare to the + :attr:`.Connection.default_isolation_level` accessor + which returns the dialect-level setting without performing a SQL + query. + + .. versionadded:: 0.9.9 + + .. seealso:: + + :attr:`.Connection.default_isolation_level` - view default level + + :paramref:`.create_engine.isolation_level` + - set per :class:`.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`.Connection` isolation level + + """ + try: + return self.dialect.get_isolation_level(self.connection) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + @property + def default_isolation_level(self): + """The default isolation level assigned to this :class:`.Connection`. + + This is the isolation level setting that the :class:`.Connection` + has when first procured via the :meth:`.Engine.connect` method. + This level stays in place until the + :paramref:`.Connection.execution_options.isolation_level` is used + to change the setting on a per-:class:`.Connection` basis. 
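+ + A sketch contrasting the two accessors; the values shown are illustrative and dialect-dependent:: + + conn = engine.connect() + conn.default_isolation_level # e.g. "READ COMMITTED" + + conn = conn.execution_options(isolation_level="SERIALIZABLE") + conn.get_isolation_level() # emits SQL; returns "SERIALIZABLE" + conn.default_isolation_level # still "READ COMMITTED"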
+ + Unlike :meth:`.Connection.get_isolation_level`, this attribute is set + ahead of time from the first connection procured by the dialect, + so SQL query is not invoked when this accessor is called. + + .. versionadded:: 0.9.9 + + .. seealso:: + + :meth:`.Connection.get_isolation_level` - view current level + + :paramref:`.create_engine.isolation_level` + - set per :class:`.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`.Connection` isolation level + + """ + return self.dialect.default_isolation_level + + def _revalidate_connection(self): + if self.__branch_from: + return self.__branch_from._revalidate_connection() + if self.__can_reconnect and self.__invalid: + if self.__transaction is not None: + raise exc.InvalidRequestError( + "Can't reconnect until invalid " + "transaction is rolled back") + self.__connection = self.engine.raw_connection(_connection=self) + self.__invalid = False + return self.__connection + raise exc.ResourceClosedError("This Connection is closed") + + @property + def _connection_is_valid(self): + # use getattr() for is_valid to support exceptions raised in + # dialect initializer, where the connection is not wrapped in + # _ConnectionFairy + + return getattr(self.__connection, 'is_valid', False) + + @property + def _still_open_and_connection_is_valid(self): + return \ + not self.closed and \ + not self.invalidated and \ + getattr(self.__connection, 'is_valid', False) @property def info(self): - """A collection of per-DB-API connection instance properties.""" + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`.Connection`, allowing user-defined + data to be associated with the connection. + + The data here will follow along with the DBAPI connection including + after it is returned to the connection pool and used again + in subsequent instances of :class:`.Connection`. + + """ return self.connection.info def connect(self): - """Returns self. + """Returns a branched version of this :class:`.Connection`. + + The :meth:`.Connection.close` method on the returned + :class:`.Connection` can be called and this + :class:`.Connection` will remain open. + + This method provides usage symmetry with + :meth:`.Engine.connect`, including for usage + with context managers. - This ``Connectable`` interface method returns self, allowing - Connections to be used interchangably with Engines in most - situations that require a bind. """ - return self + return self._branch() def contextual_connect(self, **kwargs): - """Returns self. + """Returns a branched version of this :class:`.Connection`. + + The :meth:`.Connection.close` method on the returned + :class:`.Connection` can be called and this + :class:`.Connection` will remain open. + + This method provides usage symmetry with + :meth:`.Engine.contextual_connect`, including for usage + with context managers. - This ``Connectable`` interface method returns self, allowing - Connections to be used interchangably with Engines in most - situations that require a bind. """ - return self + return self._branch() def invalidate(self, exception=None): - """Invalidate the underlying DBAPI connection associated with this Connection. + """Invalidate the underlying DBAPI connection associated with + this :class:`.Connection`. - The underlying DB-API connection is literally closed (if + The underlying DBAPI connection is literally closed (if possible), and is discarded. 
Its source connection pool will typically lazily create a new connection to replace it. - Upon the next usage, this Connection will attempt to reconnect - to the pool with a new connection. + Upon the next use (where "use" typically means using the + :meth:`.Connection.execute` method or similar), + this :class:`.Connection` will attempt to + procure a new DBAPI connection using the services of the + :class:`.Pool` as a source of connectivity (e.g. a "reconnection"). + + If a transaction was in progress (e.g. the + :meth:`.Connection.begin` method has been called) when the + :meth:`.Connection.invalidate` method is called, at the DBAPI + level all state associated with this transaction is lost, as + the DBAPI connection is closed. The :class:`.Connection` + will not allow a reconnection to proceed until the + :class:`.Transaction` object is ended, by calling the + :meth:`.Transaction.rollback` method; until that point, any attempt at + continuing to use the :class:`.Connection` will raise an + :class:`~sqlalchemy.exc.InvalidRequestError`. + This is to prevent applications from accidentally + continuing ongoing transactional operations despite the + fact that the transaction has been lost due to an + invalidation. + + The :meth:`.Connection.invalidate` method, just like auto-invalidation, + will at the connection pool level invoke the + :meth:`.PoolEvents.invalidate` event. + + .. seealso:: + + :ref:`pool_connection_invalidation` - Transactions in progress remain in an "opened" state (even though - the actual transaction is gone); these must be explicitly - rolled back before a reconnect on this Connection can proceed. This - is to prevent applications from accidentally continuing their transactional - operations in a non-transactional state. """ - if self.closed: - raise exc.InvalidRequestError("This Connection is closed") + if self.invalidated: + return - if self.__connection.is_valid: - self.__connection.invalidate(exception) - del self.__connection - self.__invalid = True + if self.closed: - raise exc.ResourceClosedError("This Connection is closed") + if self.closed: + raise exc.ResourceClosedError("This Connection is closed") + + if self._root._connection_is_valid: + self._root.__connection.invalidate(exception) + del self._root.__connection + self._root.__invalid = True def detach(self): """Detach the underlying DB-API connection from its connection pool. - This Connection instance will remain useable. When closed, + E.g.:: + + with engine.connect() as conn: + conn.detach() + conn.execute("SET search_path TO schema1, schema2") + + # work with connection + + # connection is fully closed (since we used "with:", can + # also call .close()) + + This :class:`.Connection` instance will remain usable. When closed + (or exited from a context manager as above), the DB-API connection will be literally closed and not - returned to its pool. The pool will typically lazily create a - new connection to replace the detached connection. + returned to its originating pool. This method can be used to insulate the rest of an application from a modified state on a connection (such as a transaction - isolation level or similar). Also see - :class:`~sqlalchemy.interfaces.PoolListener` for a mechanism to modify - connection state when connections leave and return to their - connection pool. + isolation level or similar). + """ self.__connection.detach() def begin(self): - """Begin a transaction and return a Transaction handle. + """Begin a transaction and return a transaction handle. + + The returned object is an instance of :class:`.Transaction`.
+ This object represents the "scope" of the transaction, + which completes when either the :meth:`.Transaction.rollback` + or :meth:`.Transaction.commit` method is called. + + Nested calls to :meth:`.begin` on the same :class:`.Connection` + will return new :class:`.Transaction` objects that represent + an emulated transaction within the scope of the enclosing + transaction, that is:: + + trans = conn.begin() # outermost transaction + trans2 = conn.begin() # "nested" + trans2.commit() # does nothing + trans.commit() # actually commits + + Calls to :meth:`.Transaction.commit` only have an effect + when invoked via the outermost :class:`.Transaction` object, though the + :meth:`.Transaction.rollback` method of any of the + :class:`.Transaction` objects will roll back the + transaction. + + See also: + + :meth:`.Connection.begin_nested` - use a SAVEPOINT + + :meth:`.Connection.begin_twophase` - use a two-phase / XA transaction + + :meth:`.Engine.begin` - context manager available from + :class:`.Engine`. - Repeated calls to ``begin`` on the same Connection will create - a lightweight, emulated nested transaction. Only the - outermost transaction may ``commit``. Calls to ``commit`` on - inner transactions are ignored. Any transaction in the - hierarchy may ``rollback``, however. """ + if self.__branch_from: + return self.__branch_from.begin() if self.__transaction is None: self.__transaction = RootTransaction(self) @@ -934,14 +608,21 @@ class Connection(Connectable): return Transaction(self, self.__transaction) def begin_nested(self): - """Begin a nested transaction and return a Transaction handle. + """Begin a nested transaction and return a transaction handle. + + The returned object is an instance of :class:`.NestedTransaction`. Nested transactions require SAVEPOINT support in the underlying database. Any transaction in the hierarchy may ``commit`` and ``rollback``, however the outermost transaction still controls the overall ``commit`` or ``rollback`` of the transaction as a whole. + + See also :meth:`.Connection.begin`, + :meth:`.Connection.begin_twophase`. """ + if self.__branch_from: + return self.__branch_from.begin_nested() if self.__transaction is None: self.__transaction = RootTransaction(self) @@ -950,18 +631,31 @@ class Connection(Connectable): return self.__transaction def begin_twophase(self, xid=None): - """Begin a two-phase or XA transaction and return a Transaction handle. + """Begin a two-phase or XA transaction and return a transaction + handle. + + The returned object is an instance of :class:`.TwoPhaseTransaction`, + which in addition to the methods provided by + :class:`.Transaction`, also provides a + :meth:`~.TwoPhaseTransaction.prepare` method. + + :param xid: the two phase transaction id. If not supplied, a + random id will be generated. + + See also :meth:`.Connection.begin`, + :meth:`.Connection.begin_nested`. - :param xid: the two phase transaction id. If not supplied, a random id - will be generated.
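+ + A sketch of typical usage, assuming a dialect with two-phase support and an existing table ``some_table``:: + + xa = conn.begin_twophase() + conn.execute(some_table.insert(), {"data": "some data"}) + xa.prepare() + xa.commit()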
""" + if self.__branch_from: + return self.__branch_from.begin_twophase(xid=xid) + if self.__transaction is not None: raise exc.InvalidRequestError( "Cannot start a two phase transaction when a transaction " "is already in progress.") if xid is None: - xid = self.engine.dialect.create_xid(); + xid = self.engine.dialect.create_xid() self.__transaction = TwoPhaseTransaction(self, xid) return self.__transaction @@ -976,100 +670,204 @@ class Connection(Connectable): def in_transaction(self): """Return True if a transaction is in progress.""" + return self._root.__transaction is not None - return self.__transaction is not None + def _begin_impl(self, transaction): + assert not self.__branch_from - def _begin_impl(self): if self._echo: - self.engine.logger.info("BEGIN") + self.engine.logger.info("BEGIN (implicit)") + + if self._has_events or self.engine._has_events: + self.dispatch.begin(self) + try: self.engine.dialect.do_begin(self.connection) - except Exception, e: + if self.connection._reset_agent is None: + self.connection._reset_agent = transaction + except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - raise def _rollback_impl(self): - # use getattr() for is_valid to support exceptions raised in dialect initializer, - # where we do not yet have the pool wrappers plugged in - if not self.closed and not self.invalidated and \ - getattr(self.__connection, 'is_valid', False): + assert not self.__branch_from + + if self._has_events or self.engine._has_events: + self.dispatch.rollback(self) + + if self._still_open_and_connection_is_valid: if self._echo: self.engine.logger.info("ROLLBACK") try: self.engine.dialect.do_rollback(self.connection) - self.__transaction = None - except Exception, e: + except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - raise + finally: + if not self.__invalid and \ + self.connection._reset_agent is self.__transaction: + self.connection._reset_agent = None + self.__transaction = None else: self.__transaction = None - def _commit_impl(self): + def _commit_impl(self, autocommit=False): + assert not self.__branch_from + + if self._has_events or self.engine._has_events: + self.dispatch.commit(self) + if self._echo: self.engine.logger.info("COMMIT") try: self.engine.dialect.do_commit(self.connection) - self.__transaction = None - except Exception, e: + except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - raise + finally: + if not self.__invalid and \ + self.connection._reset_agent is self.__transaction: + self.connection._reset_agent = None + self.__transaction = None def _savepoint_impl(self, name=None): + assert not self.__branch_from + + if self._has_events or self.engine._has_events: + self.dispatch.savepoint(self, name) + if name is None: self.__savepoint_seq += 1 name = 'sa_savepoint_%s' % self.__savepoint_seq - if self.__connection.is_valid: + if self._still_open_and_connection_is_valid: self.engine.dialect.do_savepoint(self, name) return name def _rollback_to_savepoint_impl(self, name, context): - if self.__connection.is_valid: + assert not self.__branch_from + + if self._has_events or self.engine._has_events: + self.dispatch.rollback_savepoint(self, name, context) + + if self._still_open_and_connection_is_valid: self.engine.dialect.do_rollback_to_savepoint(self, name) self.__transaction = context def _release_savepoint_impl(self, name, context): - if self.__connection.is_valid: + assert not self.__branch_from + + if self._has_events or self.engine._has_events: + 
self.dispatch.release_savepoint(self, name, context) + + if self._still_open_and_connection_is_valid: self.engine.dialect.do_release_savepoint(self, name) self.__transaction = context - def _begin_twophase_impl(self, xid): - if self.__connection.is_valid: - self.engine.dialect.do_begin_twophase(self, xid) + def _begin_twophase_impl(self, transaction): + assert not self.__branch_from + + if self._echo: + self.engine.logger.info("BEGIN TWOPHASE (implicit)") + if self._has_events or self.engine._has_events: + self.dispatch.begin_twophase(self, transaction.xid) + + if self._still_open_and_connection_is_valid: + self.engine.dialect.do_begin_twophase(self, transaction.xid) + + if self.connection._reset_agent is None: + self.connection._reset_agent = transaction def _prepare_twophase_impl(self, xid): - if self.__connection.is_valid: + assert not self.__branch_from + + if self._has_events or self.engine._has_events: + self.dispatch.prepare_twophase(self, xid) + + if self._still_open_and_connection_is_valid: assert isinstance(self.__transaction, TwoPhaseTransaction) self.engine.dialect.do_prepare_twophase(self, xid) def _rollback_twophase_impl(self, xid, is_prepared): - if self.__connection.is_valid: + assert not self.__branch_from + + if self._has_events or self.engine._has_events: + self.dispatch.rollback_twophase(self, xid, is_prepared) + + if self._still_open_and_connection_is_valid: assert isinstance(self.__transaction, TwoPhaseTransaction) - self.engine.dialect.do_rollback_twophase(self, xid, is_prepared) - self.__transaction = None + try: + self.engine.dialect.do_rollback_twophase( + self, xid, is_prepared) + finally: + if self.connection._reset_agent is self.__transaction: + self.connection._reset_agent = None + self.__transaction = None + else: + self.__transaction = None def _commit_twophase_impl(self, xid, is_prepared): - if self.__connection.is_valid: + assert not self.__branch_from + + if self._has_events or self.engine._has_events: + self.dispatch.commit_twophase(self, xid, is_prepared) + + if self._still_open_and_connection_is_valid: assert isinstance(self.__transaction, TwoPhaseTransaction) - self.engine.dialect.do_commit_twophase(self, xid, is_prepared) - self.__transaction = None + try: + self.engine.dialect.do_commit_twophase(self, xid, is_prepared) + finally: + if self.connection._reset_agent is self.__transaction: + self.connection._reset_agent = None + self.__transaction = None + else: + self.__transaction = None def _autorollback(self): - if not self.in_transaction(): - self._rollback_impl() + if not self._root.in_transaction(): + self._root._rollback_impl() def close(self): - """Close this Connection.""" + """Close this :class:`.Connection`. + This results in a release of the underlying database + resources, that is, the DBAPI connection referenced + internally. The DBAPI connection is typically restored + back to the connection-holding :class:`.Pool` referenced + by the :class:`.Engine` that produced this + :class:`.Connection`. Any transactional state present on + the DBAPI connection is also unconditionally released via + the DBAPI connection's ``rollback()`` method, regardless + of any :class:`.Transaction` object that may be + outstanding with regards to this :class:`.Connection`. + + After :meth:`~.Connection.close` is called, the + :class:`.Connection` is permanently in a closed state, + and will allow no further operations. 
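+ + For example, a sketch of explicit usage; the ``with`` statement form calls this method implicitly:: + + conn = engine.connect() + try: + conn.execute("select 1") + finally: + conn.close() # DBAPI connection is returned to the pool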
+ + """ + if self.__branch_from: + try: + del self.__connection + except AttributeError: + pass + finally: + self.__can_reconnect = False + return try: conn = self.__connection except AttributeError: - return - if not self.__branch: + pass + else: + conn.close() - self.__invalid = False - del self.__connection + if conn._reset_agent is self.__transaction: + conn._reset_agent = None + + # the close() process can end up invalidating us, + # as the pool will call our transaction as the "reset_agent" + # for rollback(), which can then cause an invalidation + if not self.__invalid: + del self.__connection + self.__can_reconnect = False self.__transaction = None - + def scalar(self, object, *multiparams, **params): """Executes and returns the first column of the first row. @@ -1079,223 +877,636 @@ class Connection(Connectable): return self.execute(object, *multiparams, **params).scalar() def execute(self, object, *multiparams, **params): - """Executes and returns a ResultProxy.""" + r"""Executes a SQL statement construct and returns a + :class:`.ResultProxy`. - for c in type(object).__mro__: - if c in Connection.executors: - return Connection.executors[c](self, object, multiparams, params) - else: - raise exc.InvalidRequestError("Unexecutable object type: " + str(type(object))) + :param object: The statement to be executed. May be + one of: - def __distill_params(self, multiparams, params): - """Given arguments from the calling form *multiparams, **params, return a list - of bind parameter structures, usually a list of dictionaries. + * a plain string + * any :class:`.ClauseElement` construct that is also + a subclass of :class:`.Executable`, such as a + :func:`~.expression.select` construct + * a :class:`.FunctionElement`, such as that generated + by :data:`.func`, will be automatically wrapped in + a SELECT statement, which is then executed. + * a :class:`.DDLElement` object + * a :class:`.DefaultGenerator` object + * a :class:`.Compiled` object + + :param \*multiparams/\**params: represent bound parameter + values to be used in the execution. Typically, + the format is either a collection of one or more + dictionaries passed to \*multiparams:: + + conn.execute( + table.insert(), + {"id":1, "value":"v1"}, + {"id":2, "value":"v2"} + ) + + ...or individual key/values interpreted by \**params:: + + conn.execute( + table.insert(), id=1, value="v1" + ) + + In the case that a plain SQL string is passed, and the underlying + DBAPI accepts positional bind parameters, a collection of tuples + or individual values in \*multiparams may be passed:: + + conn.execute( + "INSERT INTO table (id, value) VALUES (?, ?)", + (1, "v1"), (2, "v2") + ) + + conn.execute( + "INSERT INTO table (id, value) VALUES (?, ?)", + 1, "v1" + ) + + Note above, the usage of a question mark "?" or other + symbol is contingent upon the "paramstyle" accepted by the DBAPI + in use, which may be any of "qmark", "named", "pyformat", "format", + "numeric". See `pep-249 `_ + for details on paramstyle. + + To execute a textual SQL statement which uses bound parameters in a + DBAPI-agnostic way, use the :func:`~.expression.text` construct. - In the case of 'raw' execution which accepts positional parameters, - it may be a list of tuples or lists. 
- """ - - if not multiparams: - if params: - return [params] - else: - return [] - elif len(multiparams) == 1: - zero = multiparams[0] - if isinstance(zero, (list, tuple)): - if not zero or hasattr(zero[0], '__iter__'): - return zero - else: - return [zero] - elif hasattr(zero, 'keys'): - return [zero] - else: - return [[zero]] + if isinstance(object, util.string_types[0]): + return self._execute_text(object, multiparams, params) + try: + meth = object._execute_on_connection + except AttributeError: + raise exc.ObjectNotExecutableError(object) else: - if hasattr(multiparams[0], '__iter__'): - return multiparams - else: - return [multiparams] + return meth(self, multiparams, params) def _execute_function(self, func, multiparams, params): - return self._execute_clauseelement(func.select(), multiparams, params) + """Execute a sql.FunctionElement object.""" + + return self._execute_clauseelement(func.select(), + multiparams, params) def _execute_default(self, default, multiparams, params): - ctx = self.__create_execution_context() - ret = ctx._exec_default(default) + """Execute a schema.ColumnDefault object.""" + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_execute: + default, multiparams, params = \ + fn(self, default, multiparams, params) + + try: + try: + conn = self.__connection + except AttributeError: + conn = self._revalidate_connection() + + dialect = self.dialect + ctx = dialect.execution_ctx_cls._init_default( + dialect, self, conn) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + ret = ctx._exec_default(default, None) if self.should_close_with_result: self.close() + + if self._has_events or self.engine._has_events: + self.dispatch.after_execute(self, + default, multiparams, params, ret) + return ret - def _execute_ddl(self, ddl, params, multiparams): - context = self.__create_execution_context( - compiled_ddl=ddl.compile(dialect=self.dialect), - parameters=None - ) - return self.__execute_context(context) + def _execute_ddl(self, ddl, multiparams, params): + """Execute a schema.DDL object.""" + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_execute: + ddl, multiparams, params = \ + fn(self, ddl, multiparams, params) + + dialect = self.dialect + + compiled = ddl.compile( + dialect=dialect, + schema_translate_map=self.schema_for_object + if not self.schema_for_object.is_default else None) + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_ddl, + compiled, + None, + compiled + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute(self, + ddl, multiparams, params, ret) + return ret def _execute_clauseelement(self, elem, multiparams, params): - params = self.__distill_params(multiparams, params) - if params: - keys = params[0].keys() + """Execute a sql.ClauseElement object.""" + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_execute: + elem, multiparams, params = \ + fn(self, elem, multiparams, params) + + distilled_params = _distill_params(multiparams, params) + if distilled_params: + # note this is usually dict but we support RowProxy + # as well; but dict.keys() as an iterable is OK + keys = distilled_params[0].keys() else: keys = [] - context = self.__create_execution_context( - compiled_sql=elem.compile( - dialect=self.dialect, column_keys=keys, - inline=len(params) > 1), - parameters=params - ) - return self.__execute_context(context) + dialect = self.dialect + if 'compiled_cache' 
in self._execution_options: + key = ( + dialect, elem, tuple(sorted(keys)), + self.schema_for_object.hash_key, + len(distilled_params) > 1 + ) + compiled_sql = self._execution_options['compiled_cache'].get(key) + if compiled_sql is None: + compiled_sql = elem.compile( + dialect=dialect, column_keys=keys, + inline=len(distilled_params) > 1, + schema_translate_map=self.schema_for_object + if not self.schema_for_object.is_default else None + ) + self._execution_options['compiled_cache'][key] = compiled_sql + else: + compiled_sql = elem.compile( + dialect=dialect, column_keys=keys, + inline=len(distilled_params) > 1, + schema_translate_map=self.schema_for_object + if not self.schema_for_object.is_default else None) + + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_compiled, + compiled_sql, + distilled_params, + compiled_sql, distilled_params + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute(self, + elem, multiparams, params, ret) + return ret def _execute_compiled(self, compiled, multiparams, params): """Execute a sql.Compiled object.""" - context = self.__create_execution_context( - compiled_sql=compiled, - parameters=self.__distill_params(multiparams, params) - ) - return self.__execute_context(context) + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_execute: + compiled, multiparams, params = \ + fn(self, compiled, multiparams, params) + + dialect = self.dialect + parameters = _distill_params(multiparams, params) + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_compiled, + compiled, + parameters, + compiled, parameters + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute(self, + compiled, multiparams, params, ret) + return ret def _execute_text(self, statement, multiparams, params): - parameters = self.__distill_params(multiparams, params) - context = self.__create_execution_context(statement=statement, parameters=parameters) - return self.__execute_context(context) + """Execute a string SQL statement.""" + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_execute: + statement, multiparams, params = \ + fn(self, statement, multiparams, params) + + dialect = self.dialect + parameters = _distill_params(multiparams, params) + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_statement, + statement, + parameters, + statement, parameters + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute(self, + statement, multiparams, params, ret) + return ret + + def _execute_context(self, dialect, constructor, + statement, parameters, + *args): + """Create an :class:`.ExecutionContext` and execute, returning + a :class:`.ResultProxy`.""" + + try: + try: + conn = self.__connection + except AttributeError: + conn = self._revalidate_connection() + + context = constructor(dialect, self, conn, *args) + except BaseException as e: + self._handle_dbapi_exception( + e, + util.text_type(statement), parameters, + None, None) - def __execute_context(self, context): if context.compiled: context.pre_exec() - - if context.executemany: - self._cursor_executemany( - context.cursor, - context.statement, - context.parameters, context=context) - else: - self._cursor_execute( - context.cursor, - context.statement, - context.parameters[0], context=context) - + + cursor, statement, parameters = context.cursor, \ + context.statement, \ + context.parameters + + if not context.executemany: + 
parameters = parameters[0] + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_cursor_execute: + statement, parameters = \ + fn(self, cursor, statement, parameters, + context, context.executemany) + + if self._echo: + self.engine.logger.info(statement) + self.engine.logger.info( + "%r", + sql_util._repr_params(parameters, batches=10) + ) + + evt_handled = False + try: + if context.executemany: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_executemany: + if fn(cursor, statement, parameters, context): + evt_handled = True + break + if not evt_handled: + self.dialect.do_executemany( + cursor, + statement, + parameters, + context) + elif not parameters and context.no_parameters: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_execute_no_params: + if fn(cursor, statement, context): + evt_handled = True + break + if not evt_handled: + self.dialect.do_execute_no_params( + cursor, + statement, + context) + else: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_execute: + if fn(cursor, statement, parameters, context): + evt_handled = True + break + if not evt_handled: + self.dialect.do_execute( + cursor, + statement, + parameters, + context) + except BaseException as e: + self._handle_dbapi_exception( + e, + statement, + parameters, + cursor, + context) + + if self._has_events or self.engine._has_events: + self.dispatch.after_cursor_execute(self, cursor, + statement, + parameters, + context, + context.executemany) + if context.compiled: context.post_exec() - - if context.isinsert and not context.executemany: - context.post_insert() - - # create a resultproxy, get rowcount/implicit RETURNING - # rows, close cursor if no further results pending - r = context.get_result_proxy()._autoclose() - if self.__transaction is None and context.should_autocommit: - self._commit_impl() - - if r.closed and self.should_close_with_result: - self.close() - - return r - - def _handle_dbapi_exception(self, e, statement, parameters, cursor, context): - if getattr(self, '_reentrant_error', False): - # Py3K - #raise exc.DBAPIError.instance(statement, parameters, e) from e - # Py2K - raise exc.DBAPIError.instance(statement, parameters, e), None, sys.exc_info()[2] - # end Py2K - self._reentrant_error = True - try: - if not isinstance(e, self.dialect.dbapi.Error): - return + if context.is_crud or context.is_text: + result = context._setup_crud_result_proxy() + else: + result = context.get_result_proxy() + if result._metadata is None: + result._soft_close() - if context: - context.handle_dbapi_exception(e) + if context.should_autocommit and self._root.__transaction is None: + self._root._commit_impl(autocommit=True) - is_disconnect = self.dialect.is_disconnect(e) - if is_disconnect: - self.invalidate(e) - self.engine.dispose() + # for "connectionless" execution, we have to close this + # Connection after the statement is complete. + if self.should_close_with_result: + # ResultProxy already exhausted rows / has no rows. 
+ # close us now + if result._soft_closed: + self.close() else: - if cursor: - cursor.close() - self._autorollback() - if self.should_close_with_result: - self.close() - # Py3K - #raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) from e - # Py2K - raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect), None, sys.exc_info()[2] - # end Py2K - - finally: - del self._reentrant_error - - def __create_execution_context(self, **kwargs): - try: - dialect = self.engine.dialect - return dialect.execution_ctx_cls(dialect, connection=self, **kwargs) - except Exception, e: - self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None, None) - raise + # ResultProxy will close this Connection when no more + # rows to fetch. + result._autoclose_connection = True + return result def _cursor_execute(self, cursor, statement, parameters, context=None): + """Execute a statement + params on the given cursor. + + Adds appropriate logging and exception handling. + + This method is used by DefaultDialect for special-case + executions, such as for sequences and column defaults. + The path of statement execution in the majority of cases + terminates at _execute_context(). + + """ + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_cursor_execute: + statement, parameters = \ + fn(self, cursor, statement, parameters, + context, + False) + if self._echo: self.engine.logger.info(statement) self.engine.logger.info("%r", parameters) try: - self.dialect.do_execute(cursor, statement, parameters, context=context) - except Exception, e: - self._handle_dbapi_exception(e, statement, parameters, cursor, context) - raise + for fn in () if not self.dialect._has_events \ + else self.dialect.dispatch.do_execute: + if fn(cursor, statement, parameters, context): + break + else: + self.dialect.do_execute( + cursor, + statement, + parameters, + context) + except BaseException as e: + self._handle_dbapi_exception( + e, + statement, + parameters, + cursor, + context) - def _cursor_executemany(self, cursor, statement, parameters, context=None): - if self._echo: - self.engine.logger.info(statement) - self.engine.logger.info("%r", parameters) + if self._has_events or self.engine._has_events: + self.dispatch.after_cursor_execute(self, cursor, + statement, + parameters, + context, + False) + + def _safe_close_cursor(self, cursor): + """Close the given cursor, catching exceptions + and turning into log warnings. + + """ try: - self.dialect.do_executemany(cursor, statement, parameters, context=context) - except Exception, e: - self._handle_dbapi_exception(e, statement, parameters, cursor, context) - raise + cursor.close() + except Exception: + # log the error through the connection pool's logger. 
+ self.engine.pool.logger.error( + "Error closing cursor", exc_info=True) - # poor man's multimethod/generic function thingy - executors = { - expression.FunctionElement: _execute_function, - expression.ClauseElement: _execute_clauseelement, - Compiled: _execute_compiled, - schema.SchemaItem: _execute_default, - schema.DDLElement: _execute_ddl, - basestring: _execute_text - } + _reentrant_error = False + _is_disconnect = False - def create(self, entity, **kwargs): - """Create a Table or Index given an appropriate Schema object.""" + def _handle_dbapi_exception(self, + e, + statement, + parameters, + cursor, + context): + exc_info = sys.exc_info() - return self.engine.create(entity, connection=self, **kwargs) + if context and context.exception is None: + context.exception = e - def drop(self, entity, **kwargs): - """Drop a Table or Index given an appropriate Schema object.""" + is_exit_exception = not isinstance(e, Exception) - return self.engine.drop(entity, connection=self, **kwargs) + if not self._is_disconnect: + self._is_disconnect = ( + isinstance(e, self.dialect.dbapi.Error) and + not self.closed and + self.dialect.is_disconnect( + e, + self.__connection if not self.invalidated else None, + cursor) + ) or ( + is_exit_exception and not self.closed + ) - def reflecttable(self, table, include_columns=None): - """Reflect the columns in the given string table name from the database.""" + if context: + context.is_disconnect = self._is_disconnect - return self.engine.reflecttable(table, self, include_columns) + invalidate_pool_on_disconnect = not is_exit_exception - def default_schema_name(self): - return self.engine.dialect.get_default_schema_name(self) + if self._reentrant_error: + util.raise_from_cause( + exc.DBAPIError.instance(statement, + parameters, + e, + self.dialect.dbapi.Error, + dialect=self.dialect), + exc_info + ) + self._reentrant_error = True + try: + # non-DBAPI error - if we already got a context, + # or there's no string statement, don't wrap it + should_wrap = isinstance(e, self.dialect.dbapi.Error) or \ + (statement is not None + and context is None and not is_exit_exception) + + if should_wrap: + sqlalchemy_exception = exc.DBAPIError.instance( + statement, + parameters, + e, + self.dialect.dbapi.Error, + connection_invalidated=self._is_disconnect, + dialect=self.dialect) + else: + sqlalchemy_exception = None + + newraise = None + + if (self._has_events or self.engine._has_events) and \ + not self._execution_options.get( + 'skip_user_error_events', False): + # legacy dbapi_error event + if should_wrap and context: + self.dispatch.dbapi_error(self, + cursor, + statement, + parameters, + context, + e) + + # new handle_error event + ctx = ExceptionContextImpl( + e, sqlalchemy_exception, self.engine, + self, cursor, statement, + parameters, context, self._is_disconnect, + invalidate_pool_on_disconnect) + + for fn in self.dispatch.handle_error: + try: + # handler returns an exception; + # call next handler in a chain + per_fn = fn(ctx) + if per_fn is not None: + ctx.chained_exception = newraise = per_fn + except Exception as _raised: + # handler raises an exception - stop processing + newraise = _raised + break + + if self._is_disconnect != ctx.is_disconnect: + self._is_disconnect = ctx.is_disconnect + if sqlalchemy_exception: + sqlalchemy_exception.connection_invalidated = \ + ctx.is_disconnect + + # set up potentially user-defined value for + # invalidate pool. 
+ invalidate_pool_on_disconnect = \ + ctx.invalidate_pool_on_disconnect + + if should_wrap and context: + context.handle_dbapi_exception(e) + + if not self._is_disconnect: + if cursor: + self._safe_close_cursor(cursor) + with util.safe_reraise(warn_only=True): + self._autorollback() + + if newraise: + util.raise_from_cause(newraise, exc_info) + elif should_wrap: + util.raise_from_cause( + sqlalchemy_exception, + exc_info + ) + else: + util.reraise(*exc_info) + + finally: + del self._reentrant_error + if self._is_disconnect: + del self._is_disconnect + if not self.invalidated: + dbapi_conn_wrapper = self.__connection + if invalidate_pool_on_disconnect: + self.engine.pool._invalidate(dbapi_conn_wrapper, e) + self.invalidate(e) + if self.should_close_with_result: + self.close() + + @classmethod + def _handle_dbapi_exception_noconnection(cls, e, dialect, engine): + exc_info = sys.exc_info() + + is_disconnect = dialect.is_disconnect(e, None, None) + + should_wrap = isinstance(e, dialect.dbapi.Error) + + if should_wrap: + sqlalchemy_exception = exc.DBAPIError.instance( + None, + None, + e, + dialect.dbapi.Error, + connection_invalidated=is_disconnect) + else: + sqlalchemy_exception = None + + newraise = None + + if engine._has_events: + ctx = ExceptionContextImpl( + e, sqlalchemy_exception, engine, None, None, None, + None, None, is_disconnect, True) + for fn in engine.dispatch.handle_error: + try: + # handler returns an exception; + # call next handler in a chain + per_fn = fn(ctx) + if per_fn is not None: + ctx.chained_exception = newraise = per_fn + except Exception as _raised: + # handler raises an exception - stop processing + newraise = _raised + break + + if sqlalchemy_exception and \ + is_disconnect != ctx.is_disconnect: + sqlalchemy_exception.connection_invalidated = \ + is_disconnect = ctx.is_disconnect + + if newraise: + util.raise_from_cause(newraise, exc_info) + elif should_wrap: + util.raise_from_cause( + sqlalchemy_exception, + exc_info + ) + else: + util.reraise(*exc_info) def transaction(self, callable_, *args, **kwargs): - """Execute the given function within a transaction boundary. + r"""Execute the given function within a transaction boundary. + + The function is passed this :class:`.Connection` + as the first argument, followed by the given \*args and \**kwargs, + e.g.:: + + def do_something(conn, x, y): + conn.execute("some statement", {'x':x, 'y':y}) + + conn.transaction(do_something, 5, 10) + + The operations inside the function are all invoked within the + context of a single :class:`.Transaction`. + Upon success, the transaction is committed. If an + exception is raised, the transaction is rolled back + before propagating the exception. + + .. note:: + + The :meth:`.transaction` method is superseded by + the usage of the Python ``with:`` statement, which can + be used with :meth:`.Connection.begin`:: + + with conn.begin(): + conn.execute("some statement", {'x':5, 'y':10}) + + As well as with :meth:`.Engine.begin`:: + + with engine.begin() as conn: + conn.execute("some statement", {'x':5, 'y':10}) + + See also: + + :meth:`.Engine.begin` - engine-level transactional + context + + :meth:`.Engine.transaction` - engine-level version of + :meth:`.Connection.transaction` - This is a shortcut for explicitly calling `begin()` and `commit()` - and optionally `rollback()` when exceptions are raised. The - given `*args` and `**kwargs` will be passed to the function. - - See also transaction() on engine. 
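+ + A sketch of the engine-level form for comparison; it procures and releases the :class:`.Connection` itself:: + + def do_something(conn, x, y): + conn.execute("some statement", {'x': x, 'y': y}) + + engine.transaction(do_something, 5, 10)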
- """ trans = self.begin() @@ -1304,29 +1515,89 @@ class Connection(Connectable): trans.commit() return ret except: - trans.rollback() - raise + with util.safe_reraise(): + trans.rollback() def run_callable(self, callable_, *args, **kwargs): + r"""Given a callable object or function, execute it, passing + a :class:`.Connection` as the first argument. + + The given \*args and \**kwargs are passed subsequent + to the :class:`.Connection` argument. + + This function, along with :meth:`.Engine.run_callable`, + allows a function to be run with a :class:`.Connection` + or :class:`.Engine` object without the need to know + which one is being dealt with. + + """ return callable_(self, *args, **kwargs) + def _run_visitor(self, visitorcallable, element, **kwargs): + visitorcallable(self.dialect, self, + **kwargs).traverse_single(element) + + +class ExceptionContextImpl(ExceptionContext): + """Implement the :class:`.ExceptionContext` interface.""" + + def __init__(self, exception, sqlalchemy_exception, + engine, connection, cursor, statement, parameters, + context, is_disconnect, invalidate_pool_on_disconnect): + self.engine = engine + self.connection = connection + self.sqlalchemy_exception = sqlalchemy_exception + self.original_exception = exception + self.execution_context = context + self.statement = statement + self.parameters = parameters + self.is_disconnect = is_disconnect + self.invalidate_pool_on_disconnect = invalidate_pool_on_disconnect + class Transaction(object): - """Represent a Transaction in progress. + """Represent a database transaction in progress. + + The :class:`.Transaction` object is procured by + calling the :meth:`~.Connection.begin` method of + :class:`.Connection`:: + + from sqlalchemy import create_engine + engine = create_engine("postgresql://scott:tiger@localhost/test") + connection = engine.connect() + trans = connection.begin() + connection.execute("insert into x (a, b) values (1, 2)") + trans.commit() + + The object provides :meth:`.rollback` and :meth:`.commit` + methods in order to control transaction boundaries. It + also implements a context manager interface so that + the Python ``with`` statement can be used with the + :meth:`.Connection.begin` method:: + + with connection.begin(): + connection.execute("insert into x (a, b) values (1, 2)") The Transaction object is **not** threadsafe. + See also: :meth:`.Connection.begin`, :meth:`.Connection.begin_twophase`, + :meth:`.Connection.begin_nested`. + .. index:: single: thread safety; Transaction """ def __init__(self, connection, parent): self.connection = connection - self._parent = parent or self + self._actual_parent = parent self.is_active = True + @property + def _parent(self): + return self._actual_parent or self + def close(self): - """Close this transaction. + """Close this :class:`.Transaction`. If this transaction is the base transaction in a begin/commit nesting, the transaction will rollback(). Otherwise, the @@ -1334,6 +1605,7 @@ class Transaction(object): This is used to cancel a Transaction without affecting the scope of an enclosing transaction. + """ if not self._parent.is_active: return @@ -1341,6 +1613,9 @@ class Transaction(object): self.rollback() def rollback(self): + """Roll back this :class:`.Transaction`. 
+ + """ if not self._parent.is_active: return self._do_rollback() @@ -1350,6 +1625,8 @@ class Transaction(object): self._parent.rollback() def commit(self): + """Commit this :class:`.Transaction`.""" + if not self._parent.is_active: raise exc.InvalidRequestError("This transaction is inactive") self._do_commit() @@ -1363,7 +1640,11 @@ class Transaction(object): def __exit__(self, type, value, traceback): if type is None and self.is_active: - self.commit() + try: + self.commit() + except: + with util.safe_reraise(): + self.rollback() else: self.rollback() @@ -1371,7 +1652,7 @@ class Transaction(object): class RootTransaction(Transaction): def __init__(self, connection): super(RootTransaction, self).__init__(connection, None) - self.connection._begin_impl() + self.connection._begin_impl(self) def _do_rollback(self): if self.is_active: @@ -1383,27 +1664,53 @@ class RootTransaction(Transaction): class NestedTransaction(Transaction): + """Represent a 'nested', or SAVEPOINT transaction. + + A new :class:`.NestedTransaction` object may be procured + using the :meth:`.Connection.begin_nested` method. + + The interface is the same as that of :class:`.Transaction`. + + """ + def __init__(self, connection, parent): super(NestedTransaction, self).__init__(connection, parent) self._savepoint = self.connection._savepoint_impl() def _do_rollback(self): if self.is_active: - self.connection._rollback_to_savepoint_impl(self._savepoint, self._parent) + self.connection._rollback_to_savepoint_impl( + self._savepoint, self._parent) def _do_commit(self): if self.is_active: - self.connection._release_savepoint_impl(self._savepoint, self._parent) + self.connection._release_savepoint_impl( + self._savepoint, self._parent) class TwoPhaseTransaction(Transaction): + """Represent a two-phase transaction. + + A new :class:`.TwoPhaseTransaction` object may be procured + using the :meth:`.Connection.begin_twophase` method. + + The interface is the same as that of :class:`.Transaction` + with the addition of the :meth:`prepare` method. + + """ + def __init__(self, connection, xid): super(TwoPhaseTransaction, self).__init__(connection, None) self._is_prepared = False self.xid = xid - self.connection._begin_twophase_impl(self.xid) + self.connection._begin_twophase_impl(self) def prepare(self): + """Prepare this :class:`.TwoPhaseTransaction`. + + After a PREPARE, the transaction can be committed. + + """ if not self._parent.is_active: raise exc.InvalidRequestError("This transaction is inactive") self.connection._prepare_twophase_impl(self.xid) @@ -1418,15 +1725,45 @@ class TwoPhaseTransaction(Transaction): class Engine(Connectable, log.Identified): """ - Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect` - together to provide a source of database connectivity and behavior. - - An :class:`Engine` object is instantiated publically using the :func:`~sqlalchemy.create_engine` - function. + Connects a :class:`~sqlalchemy.pool.Pool` and + :class:`~sqlalchemy.engine.interfaces.Dialect` together to provide a + source of database connectivity and behavior. + + An :class:`.Engine` object is instantiated publicly using the + :func:`~sqlalchemy.create_engine` function. 
+ + See also: + + :doc:`/core/engines` + + :ref:`connections_toplevel` """ - def __init__(self, pool, dialect, url, logging_name=None, echo=None, proxy=None): + _execution_options = util.immutabledict() + _has_events = False + _connection_cls = Connection + + schema_for_object = schema._schema_getter(None) + """Return the ".schema" attribute for an object. + + Used for :class:`.Table`, :class:`.Sequence` and similar objects, + and takes into account + the :paramref:`.Connection.execution_options.schema_translate_map` + parameter. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`schema_translating` + + """ + + def __init__(self, pool, dialect, url, + logging_name=None, echo=None, proxy=None, + execution_options=None + ): self.pool = pool self.url = url self.dialect = dialect @@ -1434,101 +1771,295 @@ class Engine(Connectable, log.Identified): self.logging_name = logging_name self.echo = echo self.engine = self - self.logger = log.instance_logger(self, echoflag=echo) + log.instance_logger(self, echoflag=echo) if proxy: - self.Connection = _proxy_connection_cls(Connection, proxy) - else: - self.Connection = Connection + interfaces.ConnectionProxy._adapt_listener(self, proxy) + if execution_options: + self.update_execution_options(**execution_options) + + def update_execution_options(self, **opt): + r"""Update the default execution_options dictionary + of this :class:`.Engine`. + + The given keys/values in \**opt are added to the + default execution options that will be used for + all connections. The initial contents of this dictionary + can be sent via the ``execution_options`` parameter + to :func:`.create_engine`. + + .. seealso:: + + :meth:`.Connection.execution_options` + + :meth:`.Engine.execution_options` + + """ + self._execution_options = \ + self._execution_options.union(opt) + self.dispatch.set_engine_execution_options(self, opt) + self.dialect.set_engine_execution_options(self, opt) + + def execution_options(self, **opt): + """Return a new :class:`.Engine` that will provide + :class:`.Connection` objects with the given execution options. + + The returned :class:`.Engine` remains related to the original + :class:`.Engine` in that it shares the same connection pool and + other state: + + * The :class:`.Pool` used by the new :class:`.Engine` is the + same instance. The :meth:`.Engine.dispose` method will replace + the connection pool instance for the parent engine as well + as this one. + * Event listeners are "cascaded" - meaning, the new :class:`.Engine` + inherits the events of the parent, and new events can be associated + with the new :class:`.Engine` individually. + * The logging configuration and logging_name is copied from the parent + :class:`.Engine`. + + The intent of the :meth:`.Engine.execution_options` method is + to implement "sharding" schemes where multiple :class:`.Engine` + objects refer to the same connection pool, but are differentiated + by options that would be consumed by a custom event:: + + primary_engine = create_engine("mysql://") + shard1 = primary_engine.execution_options(shard_id="shard1") + shard2 = primary_engine.execution_options(shard_id="shard2") + + Above, the ``shard1`` engine serves as a factory for + :class:`.Connection` objects that will contain the execution option + ``shard_id=shard1``, and ``shard2`` will produce :class:`.Connection` + objects that contain the execution option ``shard_id=shard2``. + + An event handler can consume the above execution option to perform + a schema switch or other operation, given a connection. 
Below
+        we emit a MySQL ``use`` statement to switch databases, at the same
+        time keeping track of which database we've established using the
+        :attr:`.Connection.info` dictionary, which gives us a persistent
+        storage space that follows the DBAPI connection::
+
+            from sqlalchemy import event
+            from sqlalchemy.engine import Engine
+
+            shards = {"default": "base", "shard_1": "db1", "shard_2": "db2"}
+
+            @event.listens_for(Engine, "before_cursor_execute")
+            def _switch_shard(conn, cursor, stmt,
+                    params, context, executemany):
+                shard_id = conn._execution_options.get('shard_id', "default")
+                current_shard = conn.info.get("current_shard", None)
+
+                if current_shard != shard_id:
+                    cursor.execute("use %s" % shards[shard_id])
+                    conn.info["current_shard"] = shard_id
+
+        .. versionadded:: 0.8
+
+        .. seealso::
+
+            :meth:`.Connection.execution_options` - update execution options
+            on a :class:`.Connection` object.
+
+            :meth:`.Engine.update_execution_options` - update the execution
+            options for a given :class:`.Engine` in place.
+
+        """
+        return OptionEngine(self, opt)

     @property
     def name(self):
-        "String name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``."
+        """String name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
+        in use by this :class:`Engine`."""

         return self.dialect.name

     @property
     def driver(self):
-        "Driver name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``."
+        """Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
+        in use by this :class:`Engine`."""

         return self.dialect.driver

     echo = log.echo_property()

     def __repr__(self):
-        return 'Engine(%s)' % str(self.url)
+        return 'Engine(%r)' % self.url

     def dispose(self):
+        """Dispose of the connection pool used by this :class:`.Engine`.
+
+        This has the effect of fully closing all **currently checked in**
+        database connections.  Connections that are still checked out
+        will **not** be closed, however they will no longer be associated
+        with this :class:`.Engine`, so when they are closed individually,
+        eventually the :class:`.Pool` which they are associated with will
+        be garbage collected and they will be closed out fully, if
+        not already closed on checkin.
+
+        A new connection pool is created immediately after the old one has
+        been disposed.  This new pool, like all SQLAlchemy connection pools,
+        does not make any actual connections to the database until one is
+        first requested, so as long as the :class:`.Engine` isn't used again,
+        no new connections will be made.
+
+        ..
seealso:: + + :ref:`engine_disposal` + + """ self.pool.dispose() self.pool = self.pool.recreate() - - def create(self, entity, connection=None, **kwargs): - """Create a table or index within this engine's database connection given a schema.Table object.""" - - from sqlalchemy.engine import ddl - - self._run_visitor(ddl.SchemaGenerator, entity, connection=connection, **kwargs) - - def drop(self, entity, connection=None, **kwargs): - """Drop a table or index within this engine's database connection given a schema.Table object.""" - - from sqlalchemy.engine import ddl - - self._run_visitor(ddl.SchemaDropper, entity, connection=connection, **kwargs) + self.dispatch.engine_disposed(self) def _execute_default(self, default): - connection = self.contextual_connect() - try: - return connection._execute_default(default, (), {}) - finally: - connection.close() + with self.contextual_connect() as conn: + return conn._execute_default(default, (), {}) - @property - def func(self): - return expression._FunctionGenerator(bind=self) - - def text(self, text, *args, **kwargs): - """Return a sql.text() object for performing literal queries.""" - - return expression.text(text, bind=self, *args, **kwargs) - - def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): + @contextlib.contextmanager + def _optional_conn_ctx_manager(self, connection=None): if connection is None: - conn = self.contextual_connect(close_with_result=False) + with self.contextual_connect() as conn: + yield conn else: - conn = connection + yield connection + + def _run_visitor(self, visitorcallable, element, + connection=None, **kwargs): + with self._optional_conn_ctx_manager(connection) as conn: + conn._run_visitor(visitorcallable, element, **kwargs) + + class _trans_ctx(object): + def __init__(self, conn, transaction, close_with_result): + self.conn = conn + self.transaction = transaction + self.close_with_result = close_with_result + + def __enter__(self): + return self.conn + + def __exit__(self, type, value, traceback): + if type is not None: + self.transaction.rollback() + else: + self.transaction.commit() + if not self.close_with_result: + self.conn.close() + + def begin(self, close_with_result=False): + """Return a context manager delivering a :class:`.Connection` + with a :class:`.Transaction` established. + + E.g.:: + + with engine.begin() as conn: + conn.execute("insert into table (x, y, z) values (1, 2, 3)") + conn.execute("my_special_procedure(5)") + + Upon successful operation, the :class:`.Transaction` + is committed. If an error is raised, the :class:`.Transaction` + is rolled back. + + The ``close_with_result`` flag is normally ``False``, and indicates + that the :class:`.Connection` will be closed when the operation + is complete. When set to ``True``, it indicates the + :class:`.Connection` is in "single use" mode, where the + :class:`.ResultProxy` returned by the first call to + :meth:`.Connection.execute` will close the :class:`.Connection` when + that :class:`.ResultProxy` has exhausted all result rows. + + .. versionadded:: 0.7.6 + + See also: + + :meth:`.Engine.connect` - procure a :class:`.Connection` from + an :class:`.Engine`. + + :meth:`.Connection.begin` - start a :class:`.Transaction` + for a particular :class:`.Connection`. 
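+
+        A sketch of the rollback behavior, assuming a hypothetical table
+        ``mytable``::
+
+            try:
+                with engine.begin() as conn:
+                    conn.execute("insert into mytable (x) values (1)")
+                    raise RuntimeError("simulated failure")
+            except RuntimeError:
+                # the transaction was rolled back before the error
+                # propagated; the INSERT above is not persisted
+                pass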
+ + """ + conn = self.contextual_connect(close_with_result=close_with_result) try: - visitorcallable(self.dialect, conn, **kwargs).traverse(element) - finally: - if connection is None: + trans = conn.begin() + except: + with util.safe_reraise(): conn.close() + return Engine._trans_ctx(conn, trans, close_with_result) def transaction(self, callable_, *args, **kwargs): - """Execute the given function within a transaction boundary. + r"""Execute the given function within a transaction boundary. + + The function is passed a :class:`.Connection` newly procured + from :meth:`.Engine.contextual_connect` as the first argument, + followed by the given \*args and \**kwargs. + + e.g.:: + + def do_something(conn, x, y): + conn.execute("some statement", {'x':x, 'y':y}) + + engine.transaction(do_something, 5, 10) + + The operations inside the function are all invoked within the + context of a single :class:`.Transaction`. + Upon success, the transaction is committed. If an + exception is raised, the transaction is rolled back + before propagating the exception. + + .. note:: + + The :meth:`.transaction` method is superseded by + the usage of the Python ``with:`` statement, which can + be used with :meth:`.Engine.begin`:: + + with engine.begin() as conn: + conn.execute("some statement", {'x':5, 'y':10}) + + See also: + + :meth:`.Engine.begin` - engine-level transactional + context + + :meth:`.Connection.transaction` - connection-level version of + :meth:`.Engine.transaction` - This is a shortcut for explicitly calling `begin()` and `commit()` - and optionally `rollback()` when exceptions are raised. The - given `*args` and `**kwargs` will be passed to the function. - - The connection used is that of contextual_connect(). - - See also the similar method on Connection itself. - """ - - conn = self.contextual_connect() - try: + + with self.contextual_connect() as conn: return conn.transaction(callable_, *args, **kwargs) - finally: - conn.close() def run_callable(self, callable_, *args, **kwargs): - conn = self.contextual_connect() - try: + r"""Given a callable object or function, execute it, passing + a :class:`.Connection` as the first argument. + + The given \*args and \**kwargs are passed subsequent + to the :class:`.Connection` argument. + + This function, along with :meth:`.Connection.run_callable`, + allows a function to be run with a :class:`.Connection` + or :class:`.Engine` object without the need to know + which one is being dealt with. + + """ + with self.contextual_connect() as conn: return conn.run_callable(callable_, *args, **kwargs) - finally: - conn.close() def execute(self, statement, *multiparams, **params): + """Executes the given construct and returns a :class:`.ResultProxy`. + + The arguments are the same as those used by + :meth:`.Connection.execute`. + + Here, a :class:`.Connection` is acquired using the + :meth:`~.Engine.contextual_connect` method, and the statement executed + with that connection. The returned :class:`.ResultProxy` is flagged + such that when the :class:`.ResultProxy` is exhausted and its + underlying cursor is closed, the :class:`.Connection` created here + will also be closed, which allows its associated DBAPI connection + resource to be returned to the connection pool. 
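+
+        E.g., a sketch of this "connectionless" pattern, assuming a
+        hypothetical table ``users``::
+
+            result = engine.execute("select * from users")
+            for row in result:
+                print(row)
+            # once the rows are exhausted the cursor is closed and the
+            # Connection acquired here returns its DBAPI connection
+            # to the pool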
+ + """ + connection = self.contextual_connect(close_with_result=True) return connection.execute(statement, *multiparams, **params) @@ -1544,879 +2075,133 @@ class Engine(Connectable, log.Identified): return connection._execute_compiled(compiled, multiparams, params) def connect(self, **kwargs): - """Return a newly allocated Connection object.""" + """Return a new :class:`.Connection` object. - return self.Connection(self, **kwargs) + The :class:`.Connection` object is a facade that uses a DBAPI + connection internally in order to communicate with the database. This + connection is procured from the connection-holding :class:`.Pool` + referenced by this :class:`.Engine`. When the + :meth:`~.Connection.close` method of the :class:`.Connection` object + is called, the underlying DBAPI connection is then returned to the + connection pool, where it may be used again in a subsequent call to + :meth:`~.Engine.connect`. - def contextual_connect(self, close_with_result=False, **kwargs): - """Return a Connection object which may be newly allocated, or may be part of some ongoing context. - - This Connection is meant to be used by the various "auto-connecting" operations. """ - return self.Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs) + return self._connection_cls(self, **kwargs) + + def contextual_connect(self, close_with_result=False, **kwargs): + """Return a :class:`.Connection` object which may be part of some + ongoing context. + + By default, this method does the same thing as :meth:`.Engine.connect`. + Subclasses of :class:`.Engine` may override this method + to provide contextual behavior. + + :param close_with_result: When True, the first :class:`.ResultProxy` + created by the :class:`.Connection` will call the + :meth:`.Connection.close` method of that connection as soon as any + pending result rows are exhausted. This is used to supply the + "connectionless execution" behavior provided by the + :meth:`.Engine.execute` method. + + """ + + return self._connection_cls( + self, + self._wrap_pool_connect(self.pool.connect, None), + close_with_result=close_with_result, + **kwargs) def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. :param schema: Optional, retrieve names from a non-default schema. - :param connection: Optional, use a specified connection. Default is the - ``contextual_connect`` for this ``Engine``. + :param connection: Optional, use a specified connection. Default is + the ``contextual_connect`` for this ``Engine``. """ - if connection is None: - conn = self.contextual_connect() - else: - conn = connection - if not schema: - schema = self.dialect.default_schema_name - try: + with self._optional_conn_ctx_manager(connection) as conn: + if not schema: + schema = self.dialect.default_schema_name return self.dialect.get_table_names(conn, schema) - finally: - if connection is None: - conn.close() - - def reflecttable(self, table, connection=None, include_columns=None): - """Given a Table object, reflects its columns and properties from the database.""" - - if connection is None: - conn = self.contextual_connect() - else: - conn = connection - try: - self.dialect.reflecttable(conn, table, include_columns) - finally: - if connection is None: - conn.close() def has_table(self, table_name, schema=None): + """Return True if the given backend has a table of the given name. + + .. 
seealso:: + + :ref:`metadata_reflection_inspector` - detailed schema inspection + using the :class:`.Inspector` interface. + + :class:`.quoted_name` - used to pass quoting information along + with a schema identifier. + + """ return self.run_callable(self.dialect.has_table, table_name, schema) - def raw_connection(self): - """Return a DB-API connection.""" - - return self.pool.unique_connection() - - -def _proxy_connection_cls(cls, proxy): - class ProxyConnection(cls): - def execute(self, object, *multiparams, **params): - return proxy.execute(self, super(ProxyConnection, self).execute, - object, *multiparams, **params) - - def _execute_clauseelement(self, elem, multiparams=None, params=None): - return proxy.execute(self, super(ProxyConnection, self).execute, - elem, *(multiparams or []), **(params or {})) - - def _cursor_execute(self, cursor, statement, parameters, context=None): - return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, - cursor, statement, parameters, context, False) - - def _cursor_executemany(self, cursor, statement, parameters, context=None): - return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, - cursor, statement, parameters, context, True) - - def _begin_impl(self): - return proxy.begin(self, super(ProxyConnection, self)._begin_impl) - - def _rollback_impl(self): - return proxy.rollback(self, super(ProxyConnection, self)._rollback_impl) - - def _commit_impl(self): - return proxy.commit(self, super(ProxyConnection, self)._commit_impl) - - def _savepoint_impl(self, name=None): - return proxy.savepoint(self, super(ProxyConnection, self)._savepoint_impl, name=name) - - def _rollback_to_savepoint_impl(self, name, context): - return proxy.rollback_savepoint(self, - super(ProxyConnection, self)._rollback_to_savepoint_impl, - name, context) - - def _release_savepoint_impl(self, name, context): - return proxy.release_savepoint(self, - super(ProxyConnection, self)._release_savepoint_impl, - name, context) - - def _begin_twophase_impl(self, xid): - return proxy.begin_twophase(self, - super(ProxyConnection, self)._begin_twophase_impl, xid) - - def _prepare_twophase_impl(self, xid): - return proxy.prepare_twophase(self, - super(ProxyConnection, self)._prepare_twophase_impl, xid) - - def _rollback_twophase_impl(self, xid, is_prepared): - return proxy.rollback_twophase(self, - super(ProxyConnection, self)._rollback_twophase_impl, - xid, is_prepared) - - def _commit_twophase_impl(self, xid, is_prepared): - return proxy.commit_twophase(self, - super(ProxyConnection, self)._commit_twophase_impl, - xid, is_prepared) - - return ProxyConnection - -# This reconstructor is necessary so that pickles with the C extension or -# without use the same Binary format. -try: - # We need a different reconstructor on the C extension so that we can - # add extra checks that fields have correctly been initialized by - # __setstate__. - from sqlalchemy.cresultproxy import safe_rowproxy_reconstructor - - # The extra function embedding is needed so that the reconstructor function - # has the same signature whether or not the extension is present. 
- def rowproxy_reconstructor(cls, state): - return safe_rowproxy_reconstructor(cls, state) -except ImportError: - def rowproxy_reconstructor(cls, state): - obj = cls.__new__(cls) - obj.__setstate__(state) - return obj - -try: - from sqlalchemy.cresultproxy import BaseRowProxy -except ImportError: - class BaseRowProxy(object): - __slots__ = ('_parent', '_row', '_processors', '_keymap') - - def __init__(self, parent, row, processors, keymap): - """RowProxy objects are constructed by ResultProxy objects.""" - - self._parent = parent - self._row = row - self._processors = processors - self._keymap = keymap - - def __reduce__(self): - return (rowproxy_reconstructor, - (self.__class__, self.__getstate__())) - - def values(self): - """Return the values represented by this RowProxy as a list.""" - return list(self) - - def __iter__(self): - for processor, value in izip(self._processors, self._row): - if processor is None: - yield value - else: - yield processor(value) - - def __len__(self): - return len(self._row) - - def __getitem__(self, key): - try: - processor, index = self._keymap[key] - except KeyError: - processor, index = self._parent._key_fallback(key) - except TypeError: - if isinstance(key, slice): - l = [] - for processor, value in izip(self._processors[key], - self._row[key]): - if processor is None: - l.append(value) - else: - l.append(processor(value)) - return tuple(l) - else: - raise - if index is None: - raise exc.InvalidRequestError( - "Ambiguous column name '%s' in result set! " - "try 'use_labels' option on select statement." % key) - if processor is not None: - return processor(self._row[index]) - else: - return self._row[index] - - def __getattr__(self, name): - try: - # TODO: no test coverage here - return self[name] - except KeyError, e: - raise AttributeError(e.args[0]) - - -class RowProxy(BaseRowProxy): - """Proxy values from a single cursor row. - - Mostly follows "ordered dictionary" behavior, mapping result - values to the string-based column name, the integer position of - the result in the row, as well as Column instances which can be - mapped to the original Columns that produced this result set (for - results that correspond to constructed SQL expressions). 
- """ - __slots__ = () - - def __contains__(self, key): - return self._parent._has_key(self._row, key) - - def __getstate__(self): - return { - '_parent': self._parent, - '_row': tuple(self) - } - - def __setstate__(self, state): - self._parent = parent = state['_parent'] - self._row = state['_row'] - self._processors = parent._processors - self._keymap = parent._keymap - - __hash__ = None - - def __eq__(self, other): - return other is self or other == tuple(self) - - def __ne__(self, other): - return not self.__eq__(other) - - def __repr__(self): - return repr(tuple(self)) - - def has_key(self, key): - """Return True if this RowProxy contains the given key.""" - - return self._parent._has_key(self._row, key) - - def items(self): - """Return a list of tuples, each tuple containing a key/value pair.""" - # TODO: no coverage here - return [(key, self[key]) for key in self.iterkeys()] - - def keys(self): - """Return the list of keys as strings represented by this RowProxy.""" - - return self._parent.keys - - def iterkeys(self): - return iter(self._parent.keys) - - def itervalues(self): - return iter(self) - - -class ResultMetaData(object): - """Handle cursor.description, applying additional info from an execution context.""" - - def __init__(self, parent, metadata): - self._processors = processors = [] - - # We do not strictly need to store the processor in the key mapping, - # though it is faster in the Python version (probably because of the - # saved attribute lookup self._processors) - self._keymap = keymap = {} - self.keys = [] - self._echo = parent._echo - context = parent.context - dialect = context.dialect - typemap = dialect.dbapi_type_map - - for i, (colname, coltype) in enumerate(m[0:2] for m in metadata): - if dialect.description_encoding: - colname = colname.decode(dialect.description_encoding) - - if '.' in colname: - # sqlite will in some circumstances prepend table name to - # colnames, so strip - origname = colname - colname = colname.split('.')[-1] - else: - origname = None - - if context.result_map: - try: - name, obj, type_ = context.result_map[colname.lower()] - except KeyError: - name, obj, type_ = \ - colname, None, typemap.get(coltype, types.NULLTYPE) - else: - name, obj, type_ = (colname, None, typemap.get(coltype, types.NULLTYPE)) - - processor = type_.dialect_impl(dialect).\ - result_processor(dialect, coltype) - - processors.append(processor) - rec = (processor, i) - - # indexes as keys. This is only needed for the Python version of - # RowProxy (the C version uses a faster path for integer indexes). 
- keymap[i] = rec - - # Column names as keys - if keymap.setdefault(name.lower(), rec) is not rec: - # We do not raise an exception directly because several - # columns colliding by name is not a problem as long as the - # user does not try to access them (ie use an index directly, - # or the more precise ColumnElement) - keymap[name.lower()] = (processor, None) - - # store the "origname" if we truncated (sqlite only) - if origname and \ - keymap.setdefault(origname.lower(), rec) is not rec: - keymap[origname.lower()] = (processor, None) - - if dialect.requires_name_normalize: - colname = dialect.normalize_name(colname) - - self.keys.append(colname) - if obj: - for o in obj: - keymap[o] = rec - - if self._echo: - self.logger = context.engine.logger - self.logger.debug( - "Col %r", tuple(x[0] for x in metadata)) - - def _key_fallback(self, key): - map = self._keymap - result = None - if isinstance(key, basestring): - result = map.get(key.lower()) - # fallback for targeting a ColumnElement to a textual expression - # this is a rare use case which only occurs when matching text() - # constructs to ColumnElements, and after a pickle/unpickle roundtrip - elif isinstance(key, expression.ColumnElement): - if key._label and key._label.lower() in map: - result = map[key._label.lower()] - elif hasattr(key, 'name') and key.name.lower() in map: - result = map[key.name.lower()] - if result is None: - raise exc.NoSuchColumnError( - "Could not locate column in row for column '%s'" % key) - else: - map[key] = result - return result - - def _has_key(self, row, key): - if key in self._keymap: - return True - else: - try: - self._key_fallback(key) - return True - except exc.NoSuchColumnError: - return False - - def __len__(self): - return len(self.keys) - - def __getstate__(self): - return { - '_pickled_keymap': dict( - (key, index) - for key, (processor, index) in self._keymap.iteritems() - if isinstance(key, (basestring, int)) - ), - 'keys': self.keys - } - - def __setstate__(self, state): - # the row has been processed at pickling time so we don't need any - # processor anymore - self._processors = [None for _ in xrange(len(state['keys']))] - self._keymap = keymap = {} - for key, index in state['_pickled_keymap'].iteritems(): - keymap[key] = (None, index) - self.keys = state['keys'] - self._echo = False - - -class ResultProxy(object): - """Wraps a DB-API cursor object to provide easier access to row columns. - - Individual columns may be accessed by their integer position, - case-insensitive column name, or by ``schema.Column`` - object. e.g.:: - - row = fetchone() - - col1 = row[0] # access via integer position - - col2 = row['col2'] # access via name - - col3 = row[mytable.c.mycol] # access via Column object. - - ``ResultProxy`` also handles post-processing of result column - data using ``TypeEngine`` objects, which are referenced from - the originating SQL statement that produced this result set. 
- - """ - - _process_row = RowProxy - out_parameters = None - _can_close_connection = False - - def __init__(self, context): - self.context = context - self.dialect = context.dialect - self.closed = False - self.cursor = context.cursor - self.connection = context.root_connection - self._echo = self.connection._echo and \ - context.engine._should_log_debug() - self._init_metadata() - - def _init_metadata(self): - metadata = self._cursor_description() - if metadata is None: - self._metadata = None - else: - self._metadata = ResultMetaData(self, metadata) - - def keys(self): - """Return the current set of string keys for rows.""" - if self._metadata: - return self._metadata.keys - else: - return [] - - @util.memoized_property - def rowcount(self): - """Return the 'rowcount' for this result. - - The 'rowcount' reports the number of rows affected - by an UPDATE or DELETE statement. It has *no* other - uses and is not intended to provide the number of rows - present from a SELECT. - - Note that this row count may not be properly implemented - in some dialects; this is indicated by - :meth:`~sqlalchemy.engine.base.ResultProxy.supports_sane_rowcount()` and - :meth:`~sqlalchemy.engine.base.ResultProxy.supports_sane_multi_rowcount()`. - - ``rowcount()`` also may not work at this time for a statement - that uses ``returning()``. - - """ - return self.context.rowcount - - @property - def lastrowid(self): - """return the 'lastrowid' accessor on the DBAPI cursor. - - This is a DBAPI specific method and is only functional - for those backends which support it, for statements - where it is appropriate. It's behavior is not - consistent across backends. - - Usage of this method is normally unnecessary; the - inserted_primary_key method provides a - tuple of primary key values for a newly inserted row, - regardless of database backend. - - """ - return self.cursor.lastrowid - - def _cursor_description(self): - """May be overridden by subclasses.""" - - return self.cursor.description - - def _autoclose(self): - """called by the Connection to autoclose cursors that have no pending results - beyond those used by an INSERT/UPDATE/DELETE with no explicit RETURNING clause. - - """ - if self.context.isinsert: - if self.context._is_implicit_returning: - self.context._fetch_implicit_returning(self) - self.close(_autoclose_connection=False) - elif not self.context._is_explicit_returning: - self.close(_autoclose_connection=False) - elif self._metadata is None: - # no results, get rowcount - # (which requires open cursor on some drivers - # such as kintersbasdb, mxodbc), - self.rowcount - self.close(_autoclose_connection=False) - - return self - - def close(self, _autoclose_connection=True): - """Close this ResultProxy. - - Closes the underlying DBAPI cursor corresponding to the execution. - - Note that any data cached within this ResultProxy is still available. - For some types of results, this may include buffered rows. - - If this ResultProxy was generated from an implicit execution, - the underlying Connection will also be closed (returns the - underlying DBAPI connection to the connection pool.) - - This method is called automatically when: - - * all result rows are exhausted using the fetchXXX() methods. - * cursor.description is None. 
- - """ - - if not self.closed: - self.closed = True - self.cursor.close() - if _autoclose_connection and \ - self.connection.should_close_with_result: - self.connection.close() - - def __iter__(self): - while True: - row = self.fetchone() - if row is None: - raise StopIteration - else: - yield row - - @util.memoized_property - def inserted_primary_key(self): - """Return the primary key for the row just inserted. - - This only applies to single row insert() constructs which - did not explicitly specify returning(). - - """ - if not self.context.isinsert: - raise exc.InvalidRequestError("Statement is not an insert() expression construct.") - elif self.context._is_explicit_returning: - raise exc.InvalidRequestError("Can't call inserted_primary_key when returning() is used.") - - return self.context._inserted_primary_key - - @util.deprecated("Use inserted_primary_key") - def last_inserted_ids(self): - """deprecated. use inserted_primary_key.""" - - return self.inserted_primary_key - - def last_updated_params(self): - """Return ``last_updated_params()`` from the underlying ExecutionContext. - - See ExecutionContext for details. - """ - - return self.context.last_updated_params() - - def last_inserted_params(self): - """Return ``last_inserted_params()`` from the underlying ExecutionContext. - - See ExecutionContext for details. - """ - - return self.context.last_inserted_params() - - def lastrow_has_defaults(self): - """Return ``lastrow_has_defaults()`` from the underlying ExecutionContext. - - See ExecutionContext for details. - """ - - return self.context.lastrow_has_defaults() - - def postfetch_cols(self): - """Return ``postfetch_cols()`` from the underlying ExecutionContext. - - See ExecutionContext for details. - """ - - return self.context.postfetch_cols - - def prefetch_cols(self): - return self.context.prefetch_cols - - def supports_sane_rowcount(self): - """Return ``supports_sane_rowcount`` from the dialect.""" - - return self.dialect.supports_sane_rowcount - - def supports_sane_multi_rowcount(self): - """Return ``supports_sane_multi_rowcount`` from the dialect.""" - - return self.dialect.supports_sane_multi_rowcount - - def _fetchone_impl(self): - return self.cursor.fetchone() - - def _fetchmany_impl(self, size=None): - return self.cursor.fetchmany(size) - - def _fetchall_impl(self): - return self.cursor.fetchall() - - def process_rows(self, rows): - process_row = self._process_row - metadata = self._metadata - keymap = metadata._keymap - processors = metadata._processors - if self._echo: - log = self.context.engine.logger.debug - l = [] - for row in rows: - log("Row %r", row) - l.append(process_row(metadata, row, processors, keymap)) - return l - else: - return [process_row(metadata, row, processors, keymap) - for row in rows] - - def fetchall(self): - """Fetch all rows, just like DB-API ``cursor.fetchall()``.""" - + def _wrap_pool_connect(self, fn, connection): + dialect = self.dialect try: - l = self.process_rows(self._fetchall_impl()) - self.close() - return l - except Exception, e: - self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) - raise - - def fetchmany(self, size=None): - """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``. - - If rows are present, the cursor remains open after this is called. - Else the cursor is automatically closed and an empty list is returned. 
- - """ - - try: - l = self.process_rows(self._fetchmany_impl(size)) - if len(l) == 0: - self.close() - return l - except Exception, e: - self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) - raise - - def fetchone(self): - """Fetch one row, just like DB-API ``cursor.fetchone()``. - - If a row is present, the cursor remains open after this is called. - Else the cursor is automatically closed and None is returned. - - """ - - try: - row = self._fetchone_impl() - if row is not None: - return self.process_rows([row])[0] + return fn() + except dialect.dbapi.Error as e: + if connection is None: + Connection._handle_dbapi_exception_noconnection( + e, dialect, self) else: - self.close() - return None - except Exception, e: - self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) - raise + util.reraise(*sys.exc_info()) + + def raw_connection(self, _connection=None): + """Return a "raw" DBAPI connection from the connection pool. + + The returned object is a proxied version of the DBAPI + connection object used by the underlying driver in use. + The object will have all the same behavior as the real DBAPI + connection, except that its ``close()`` method will result in the + connection being returned to the pool, rather than being closed + for real. + + This method provides direct DBAPI connection access for + special situations when the API provided by :class:`.Connection` + is not needed. When a :class:`.Connection` object is already + present, the DBAPI connection is available using + the :attr:`.Connection.connection` accessor. + + .. seealso:: + + :ref:`dbapi_connections` - def first(self): - """Fetch the first row and then close the result set unconditionally. - - Returns None if no row is present. - """ - try: - row = self._fetchone_impl() - except Exception, e: - self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) - raise + return self._wrap_pool_connect( + self.pool.unique_connection, _connection) - try: - if row is not None: - return self.process_rows([row])[0] - else: - return None - finally: - self.close() - - def scalar(self): - """Fetch the first column of the first row, and close the result set. - - Returns None if no row is present. - - """ - row = self.first() - if row is not None: - return row[0] - else: - return None -class BufferedRowResultProxy(ResultProxy): - """A ResultProxy with row buffering behavior. +class OptionEngine(Engine): + def __init__(self, proxied, execution_options): + self._proxied = proxied + self.url = proxied.url + self.dialect = proxied.dialect + self.logging_name = proxied.logging_name + self.echo = proxied.echo + log.instance_logger(self, echoflag=self.echo) + self.dispatch = self.dispatch._join(proxied.dispatch) + self._execution_options = proxied._execution_options + self.update_execution_options(**execution_options) - ``ResultProxy`` that buffers the contents of a selection of rows - before ``fetchone()`` is called. This is to allow the results of - ``cursor.description`` to be available immediately, when - interfacing with a DB-API that requires rows to be consumed before - this information is available (currently psycopg2, when used with - server-side cursors). + def _get_pool(self): + return self._proxied.pool - The pre-fetching behavior fetches only one row initially, and then - grows its buffer size by a fixed amount with each successive need - for additional rows up to a size of 100. 
- """ + def _set_pool(self, pool): + self._proxied.pool = pool - def _init_metadata(self): - self.__buffer_rows() - super(BufferedRowResultProxy, self)._init_metadata() + pool = property(_get_pool, _set_pool) - # this is a "growth chart" for the buffering of rows. - # each successive __buffer_rows call will use the next - # value in the list for the buffer size until the max - # is reached - size_growth = { - 1 : 5, - 5 : 10, - 10 : 20, - 20 : 50, - 50 : 100 - } + def _get_has_events(self): + return self._proxied._has_events or \ + self.__dict__.get('_has_events', False) - def __buffer_rows(self): - size = getattr(self, '_bufsize', 1) - self.__rowbuffer = self.cursor.fetchmany(size) - self._bufsize = self.size_growth.get(size, size) + def _set_has_events(self, value): + self.__dict__['_has_events'] = value - def _fetchone_impl(self): - if self.closed: - return None - if len(self.__rowbuffer) == 0: - self.__buffer_rows() - if len(self.__rowbuffer) == 0: - return None - return self.__rowbuffer.pop(0) - - def _fetchmany_impl(self, size=None): - result = [] - for x in range(0, size): - row = self._fetchone_impl() - if row is None: - break - result.append(row) - return result - - def _fetchall_impl(self): - ret = self.__rowbuffer + list(self.cursor.fetchall()) - self.__rowbuffer[:] = [] - return ret - -class FullyBufferedResultProxy(ResultProxy): - """A result proxy that buffers rows fully upon creation. - - Used for operations where a result is to be delivered - after the database conversation can not be continued, - such as MSSQL INSERT...OUTPUT after an autocommit. - - """ - def _init_metadata(self): - super(FullyBufferedResultProxy, self)._init_metadata() - self.__rowbuffer = self._buffer_rows() - - def _buffer_rows(self): - return self.cursor.fetchall() - - def _fetchone_impl(self): - if self.__rowbuffer: - return self.__rowbuffer.pop(0) - else: - return None - - def _fetchmany_impl(self, size=None): - result = [] - for x in range(0, size): - row = self._fetchone_impl() - if row is None: - break - result.append(row) - return result - - def _fetchall_impl(self): - ret = self.__rowbuffer - self.__rowbuffer = [] - return ret - -class BufferedColumnRow(RowProxy): - def __init__(self, parent, row, processors, keymap): - # preprocess row - row = list(row) - # this is a tad faster than using enumerate - index = 0 - for processor in parent._orig_processors: - if processor is not None: - row[index] = processor(row[index]) - index += 1 - row = tuple(row) - super(BufferedColumnRow, self).__init__(parent, row, - processors, keymap) - -class BufferedColumnResultProxy(ResultProxy): - """A ResultProxy with column buffering behavior. - - ``ResultProxy`` that loads all columns into memory each time - fetchone() is called. If fetchmany() or fetchall() are called, - the full grid of results is fetched. This is to operate with - databases where result rows contain "live" results that fall out - of scope unless explicitly fetched. Currently this includes - cx_Oracle LOB objects. - - """ - - _process_row = BufferedColumnRow - - def _init_metadata(self): - super(BufferedColumnResultProxy, self)._init_metadata() - metadata = self._metadata - # orig_processors will be used to preprocess each row when they are - # constructed. - metadata._orig_processors = metadata._processors - # replace the all type processors by None processors. 
- metadata._processors = [None for _ in xrange(len(metadata.keys))] - keymap = {} - for k, (func, index) in metadata._keymap.iteritems(): - keymap[k] = (None, index) - self._metadata._keymap = keymap - - def fetchall(self): - # can't call cursor.fetchall(), since rows must be - # fully processed before requesting more from the DBAPI. - l = [] - while True: - row = self.fetchone() - if row is None: - break - l.append(row) - return l - - def fetchmany(self, size=None): - # can't call cursor.fetchmany(), since rows must be - # fully processed before requesting more from the DBAPI. - if size is None: - return self.fetchall() - l = [] - for i in xrange(size): - row = self.fetchone() - if row is None: - break - l.append(row) - return l - -def connection_memoize(key): - """Decorator, memoize a function in a connection.info stash. - - Only applicable to functions which take no arguments other than a - connection. The memo will be stored in ``connection.info[key]``. - """ - - @util.decorator - def decorated(fn, self, connection): - connection = connection.connect() - try: - return connection.info[key] - except KeyError: - connection.info[key] = val = fn(self, connection) - return val - - return decorated + _has_events = property(_get_has_events, _set_has_events) diff --git a/sqlalchemy/engine/default.py b/sqlalchemy/engine/default.py index 6fb0a14..bcc78be 100644 --- a/sqlalchemy/engine/default.py +++ b/sqlalchemy/engine/default.py @@ -1,5 +1,6 @@ # engine/default.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -12,16 +13,27 @@ as the base class for their own corresponding classes. """ -import re, random -from sqlalchemy.engine import base, reflection -from sqlalchemy.sql import compiler, expression -from sqlalchemy import exc, types as sqltypes, util +import re +import random +from . import reflection, interfaces, result +from ..sql import compiler, expression, schema +from .. import types as sqltypes +from .. import exc, util, pool, processors +import codecs +import weakref +from .. import event -AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', - re.I | re.UNICODE) +AUTOCOMMIT_REGEXP = re.compile( + r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', + re.I | re.UNICODE) + +# When we're handed literal SQL, ensure it's a SELECT query +SERVER_SIDE_CURSOR_RE = re.compile( + r'\s*SELECT', + re.I | re.UNICODE) -class DefaultDialect(base.Dialect): +class DefaultDialect(interfaces.Dialect): """Default implementation of Dialect""" statement_compiler = compiler.SQLCompiler @@ -30,35 +42,68 @@ class DefaultDialect(base.Dialect): preparer = compiler.IdentifierPreparer supports_alter = True + # the first value we'd get for an autoincrement + # column. + default_sequence_base = 1 + # most DBAPIs happy with this for execute(). - # not cx_oracle. + # not cx_oracle. 
execute_sequence_format = tuple
-
+
+    supports_views = True
     supports_sequences = False
     sequences_optional = False
     preexecute_autoincrement_sequences = False
     postfetch_lastrowid = True
     implicit_returning = False
-
+
+    supports_right_nested_joins = True
+
     supports_native_enum = False
     supports_native_boolean = False
-
+
+    supports_simple_order_by_label = True
+
+    engine_config_types = util.immutabledict([
+        ('convert_unicode', util.bool_or_str('force')),
+        ('pool_timeout', util.asint),
+        ('echo', util.bool_or_str('debug')),
+        ('echo_pool', util.bool_or_str('debug')),
+        ('pool_recycle', util.asint),
+        ('pool_size', util.asint),
+        ('max_overflow', util.asint),
+        ('pool_threadlocal', util.asbool),
+    ])
+
     # if the NUMERIC type
     # returns decimal.Decimal.
     # *not* the FLOAT type however.
     supports_native_decimal = False
-
-    # Py3K
-    #supports_unicode_statements = True
-    #supports_unicode_binds = True
-    # Py2K
-    supports_unicode_statements = False
-    supports_unicode_binds = False
-    returns_unicode_strings = False
-    # end Py2K
+
+    if util.py3k:
+        supports_unicode_statements = True
+        supports_unicode_binds = True
+        returns_unicode_strings = True
+        description_encoding = None
+    else:
+        supports_unicode_statements = False
+        supports_unicode_binds = False
+        returns_unicode_strings = False
+        description_encoding = 'use_encoding'

     name = 'default'
+
+    # length at which to truncate
+    # any identifier.
     max_identifier_length = 9999
+
+    # length at which to truncate
+    # the name of an index.
+    # Usually None to indicate
+    # 'use max_identifier_length'.
+    # thanks to MySQL, sigh
+    max_index_name_length = None
+
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
     dbapi_type_map = {}
@@ -66,36 +111,81 @@ class DefaultDialect(base.Dialect):
     default_paramstyle = 'named'
     supports_default_values = False
     supports_empty_insert = True
-
+    supports_multivalues_insert = False
+
+    supports_server_side_cursors = False
+
     server_version_info = None
-
-    # indicates symbol names are
+
+    construct_arguments = None
+    """Optional set of argument specifiers for various SQLAlchemy
+    constructs, typically schema items.
+
+    To implement, establish as a series of tuples, as in::
+
+        construct_arguments = [
+            (schema.Index, {
+                "using": False,
+                "where": None,
+                "ops": None
+            })
+        ]
+
+    If the above construct is established on the PostgreSQL dialect,
+    the :class:`.Index` construct will now accept the keyword arguments
+    ``postgresql_using``, ``postgresql_where``, and ``postgresql_ops``.
+    Any other argument specified to the constructor of :class:`.Index`
+    which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`.
+
+    A dialect which does not include a ``construct_arguments`` member will
+    not participate in the argument validation system.  For such a dialect,
+    any argument name is accepted by all participating constructs, within
+    the namespace of arguments prefixed with that dialect name.  The rationale
+    here is so that third-party dialects that haven't yet implemented this
+    feature continue to function in the old way.
+
+    .. versionadded:: 0.9.2
+
+    .. seealso::
+
+        :class:`.DialectKWArgs` - implementing base class which consumes
+        :attr:`.DefaultDialect.construct_arguments`
+
+
+    """
+
+    # indicates symbol names are
     # UPPERCASEd if they are case insensitive
     # within the database.
     # if this is True, the methods normalize_name()
     # and denormalize_name() must be provided.
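+    # (for example, Oracle and Firebird store case-insensitive
+    # identifiers as uppercase, and their dialects set this flag)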
requires_name_normalize = False - + reflection_options = () - def __init__(self, convert_unicode=False, assert_unicode=False, + dbapi_exception_translation_map = util.immutabledict() + """mapping used in the extremely unusual case that a DBAPI's + published exceptions don't actually have the __name__ that they + are linked towards. + + .. versionadded:: 1.0.5 + + """ + + def __init__(self, convert_unicode=False, encoding='utf-8', paramstyle=None, dbapi=None, implicit_returning=None, + supports_right_nested_joins=None, + case_sensitive=True, + supports_native_boolean=None, label_length=None, **kwargs): - + if not getattr(self, 'ported_sqla_06', True): util.warn( - "The %s dialect is not yet ported to SQLAlchemy 0.6" % self.name) - + "The %s dialect is not yet ported to the 0.6 format" % + self.name) + self.convert_unicode = convert_unicode - if assert_unicode: - util.warn_deprecated("assert_unicode is deprecated. " - "SQLAlchemy emits a warning in all cases where it " - "would otherwise like to encode a Python unicode object " - "into a specific encoding but a plain bytestring is received. " - "This does *not* apply to DBAPIs that coerce Unicode natively." - ) - self.encoding = encoding self.positional = False self._ischema = None @@ -111,104 +201,191 @@ class DefaultDialect(base.Dialect): self.positional = self.paramstyle in ('qmark', 'format', 'numeric') self.identifier_preparer = self.preparer(self) self.type_compiler = self.type_compiler(self) + if supports_right_nested_joins is not None: + self.supports_right_nested_joins = supports_right_nested_joins + if supports_native_boolean is not None: + self.supports_native_boolean = supports_native_boolean + self.case_sensitive = case_sensitive if label_length and label_length > self.max_identifier_length: - raise exc.ArgumentError("Label length of %d is greater than this dialect's" - " maximum identifier length of %d" % - (label_length, self.max_identifier_length)) + raise exc.ArgumentError( + "Label length of %d is greater than this dialect's" + " maximum identifier length of %d" % + (label_length, self.max_identifier_length)) self.label_length = label_length - if not hasattr(self, 'description_encoding'): - self.description_encoding = getattr(self, 'description_encoding', encoding) - + if self.description_encoding == 'use_encoding': + self._description_decoder = \ + processors.to_unicode_processor_factory( + encoding + ) + elif self.description_encoding is not None: + self._description_decoder = \ + processors.to_unicode_processor_factory( + self.description_encoding + ) + self._encoder = codecs.getencoder(self.encoding) + self._decoder = processors.to_unicode_processor_factory(self.encoding) + + @util.memoized_property + def _type_memos(self): + return weakref.WeakKeyDictionary() + @property def dialect_description(self): return self.name + "+" + self.driver - + + @classmethod + def get_pool_class(cls, url): + return getattr(cls, 'poolclass', pool.QueuePool) + def initialize(self, connection): try: - self.server_version_info = self._get_server_version_info(connection) + self.server_version_info = \ + self._get_server_version_info(connection) except NotImplementedError: self.server_version_info = None try: - self.default_schema_name = self._get_default_schema_name(connection) + self.default_schema_name = \ + self._get_default_schema_name(connection) except NotImplementedError: self.default_schema_name = None + try: + self.default_isolation_level = \ + self.get_isolation_level(connection.connection) + except NotImplementedError: + 
self.default_isolation_level = None + self.returns_unicode_strings = self._check_unicode_returns(connection) - + + if self.description_encoding is not None and \ + self._check_unicode_description(connection): + self._description_decoder = self.description_encoding = None + self.do_rollback(connection.connection) - + def on_connect(self): """return a callable which sets up a newly created DBAPI connection. - - This is used to set dialect-wide per-connection options such as isolation - modes, unicode modes, etc. - + + This is used to set dialect-wide per-connection options such as + isolation modes, unicode modes, etc. + If a callable is returned, it will be assembled into a pool listener that receives the direct DBAPI connection, with all wrappers removed. - + If None is returned, no listener will be generated. - + """ return None - - def _check_unicode_returns(self, connection): - # Py2K - if self.supports_unicode_statements: - cast_to = unicode + + def _check_unicode_returns(self, connection, additional_tests=None): + if util.py2k and not self.supports_unicode_statements: + cast_to = util.binary_type else: - cast_to = str - # end Py2K - # Py3K - #cast_to = str - def check_unicode(type_): - cursor = connection.connection.cursor() + cast_to = util.text_type + + if self.positional: + parameters = self.execute_sequence_format() + else: + parameters = {} + + def check_unicode(test): + statement = cast_to( + expression.select([test]).compile(dialect=self)) try: - cursor.execute( - cast_to( - expression.select( - [expression.cast( - expression.literal_column("'test unicode returns'"), type_) - ]).compile(dialect=self) - ) - ) + cursor = connection.connection.cursor() + connection._cursor_execute(cursor, statement, parameters) row = cursor.fetchone() - - return isinstance(row[0], unicode) - finally: cursor.close() - - # detect plain VARCHAR - unicode_for_varchar = check_unicode(sqltypes.VARCHAR(60)) - - # detect if there's an NVARCHAR type with different behavior available - unicode_for_unicode = check_unicode(sqltypes.Unicode(60)) - - if unicode_for_unicode and not unicode_for_varchar: + except exc.DBAPIError as de: + # note that _cursor_execute() will have closed the cursor + # if an exception is thrown. + util.warn("Exception attempting to " + "detect unicode returns: %r" % de) + return False + else: + return isinstance(row[0], util.text_type) + + tests = [ + # detect plain VARCHAR + expression.cast( + expression.literal_column("'test plain returns'"), + sqltypes.VARCHAR(60) + ), + # detect if there's an NVARCHAR type with different behavior + # available + expression.cast( + expression.literal_column("'test unicode returns'"), + sqltypes.Unicode(60) + ), + ] + + if additional_tests: + tests += additional_tests + + results = set([check_unicode(test) for test in tests]) + + if results.issuperset([True, False]): return "conditional" else: - return unicode_for_varchar - + return results == set([True]) + + def _check_unicode_description(self, connection): + # all DBAPIs on Py2K return cursor.description as encoded, + # until pypy2.1beta2 with sqlite, so let's just check it - + # it's likely others will start doing this too in Py2k. 
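+        #
+        # as a sketch of what the check below does: compile a simple
+        # labeled SELECT, run it on a raw cursor, and report whether the
+        # DBAPI hands back cursor.description names as text (unicode)
+        # objects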
+ + if util.py2k and not self.supports_unicode_statements: + cast_to = util.binary_type + else: + cast_to = util.text_type + + cursor = connection.connection.cursor() + try: + cursor.execute( + cast_to( + expression.select([ + expression.literal_column("'x'").label("some_label") + ]).compile(dialect=self) + ) + ) + return isinstance(cursor.description[0][0], util.text_type) + finally: + cursor.close() + def type_descriptor(self, typeobj): - """Provide a database-specific ``TypeEngine`` object, given + """Provide a database-specific :class:`.TypeEngine` object, given the generic object which comes from the types module. This method looks for a dictionary called ``colspecs`` as a class or instance-level variable, - and passes on to ``types.adapt_type()``. + and passes on to :func:`.types.adapt_type`. """ return sqltypes.adapt_type(typeobj, self.colspecs) - def reflecttable(self, connection, table, include_columns): + def reflecttable( + self, connection, table, include_columns, exclude_columns, **opts): insp = reflection.Inspector.from_engine(connection) - return insp.reflecttable(table, include_columns) + return insp.reflecttable( + table, include_columns, exclude_columns, **opts) + + def get_pk_constraint(self, conn, table_name, schema=None, **kw): + """Compatibility method, adapts the result of get_primary_keys() + for those dialects which don't implement get_pk_constraint(). + + """ + return { + 'constrained_columns': + self.get_primary_keys(conn, table_name, + schema=schema, **kw) + } def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: raise exc.IdentifierError( - "Identifier '%s' exceeds maximum length of %d characters" % + "Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length) ) @@ -220,26 +397,53 @@ class DefaultDialect(base.Dialect): opts.update(url.query) return [[], opts] - def do_begin(self, connection): - """Implementations might want to put logic here for turning - autocommit on/off, etc. - """ + def set_engine_execution_options(self, engine, opts): + if 'isolation_level' in opts: + isolation_level = opts['isolation_level'] + @event.listens_for(engine, "engine_connect") + def set_isolation(connection, branch): + if not branch: + self._set_connection_isolation(connection, isolation_level) + + if 'schema_translate_map' in opts: + getter = schema._schema_getter(opts['schema_translate_map']) + engine.schema_for_object = getter + + @event.listens_for(engine, "engine_connect") + def set_schema_translate_map(connection, branch): + connection.schema_for_object = getter + + def set_connection_execution_options(self, connection, opts): + if 'isolation_level' in opts: + self._set_connection_isolation(connection, opts['isolation_level']) + + if 'schema_translate_map' in opts: + getter = schema._schema_getter(opts['schema_translate_map']) + connection.schema_for_object = getter + + def _set_connection_isolation(self, connection, level): + if connection.in_transaction(): + util.warn( + "Connection is already established with a Transaction; " + "setting isolation_level may implicitly rollback or commit " + "the existing transaction, or have no effect until " + "next transaction") + self.set_isolation_level(connection.connection, level) + connection.connection._connection_record.\ + finalize_callback.append(self.reset_isolation_level) + + def do_begin(self, dbapi_connection): pass - def do_rollback(self, connection): - """Implementations might want to put logic here for turning - autocommit on/off, etc. 
- """ + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback() - connection.rollback() + def do_commit(self, dbapi_connection): + dbapi_connection.commit() - def do_commit(self, connection): - """Implementations might want to put logic here for turning - autocommit on/off, etc. - """ - - connection.commit() + def do_close(self, dbapi_connection): + dbapi_connection.close() def create_xid(self): """Create a random two-phase transaction ID. @@ -265,281 +469,413 @@ class DefaultDialect(base.Dialect): def do_execute(self, cursor, statement, parameters, context=None): cursor.execute(statement, parameters) - def is_disconnect(self, e): + def do_execute_no_params(self, cursor, statement, context=None): + cursor.execute(statement) + + def is_disconnect(self, e, connection, cursor): return False + def reset_isolation_level(self, dbapi_conn): + # default_isolation_level is read from the first connection + # after the initial set of 'isolation_level', if any, so is + # the configured default of this dialect. + self.set_isolation_level(dbapi_conn, self.default_isolation_level) -class DefaultExecutionContext(base.ExecutionContext): - execution_options = util.frozendict() + +class StrCompileDialect(DefaultDialect): + + statement_compiler = compiler.StrSQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler = compiler.StrSQLTypeCompiler + preparer = compiler.IdentifierPreparer + + supports_sequences = True + sequences_optional = True + preexecute_autoincrement_sequences = False + implicit_returning = False + + supports_native_boolean = True + + supports_simple_order_by_label = True + + +class DefaultExecutionContext(interfaces.ExecutionContext): isinsert = False isupdate = False isdelete = False + is_crud = False + is_text = False isddl = False executemany = False - result_map = None compiled = None statement = None - - def __init__(self, - dialect, - connection, - compiled_sql=None, - compiled_ddl=None, - statement=None, - parameters=None): - - self.dialect = dialect - self._connection = self.root_connection = connection - self.engine = connection.engine - - if compiled_ddl is not None: - self.compiled = compiled = compiled_ddl - self.isddl = True + result_column_struct = None + returned_defaults = None + _is_implicit_returning = False + _is_explicit_returning = False - if compiled.statement._execution_options: - self.execution_options = compiled.statement._execution_options - if connection._execution_options: - self.execution_options = self.execution_options.union( - connection._execution_options - ) + # a hook for SQLite's translation of + # result column names + _translate_colname = None - if not dialect.supports_unicode_statements: - self.unicode_statement = unicode(compiled) - self.statement = self.unicode_statement.encode(self.dialect.encoding) - else: - self.statement = self.unicode_statement = unicode(compiled) - - self.cursor = self.create_cursor() - self.compiled_parameters = [] - self.parameters = [self._default_params] - - elif compiled_sql is not None: - self.compiled = compiled = compiled_sql + @classmethod + def _init_ddl(cls, dialect, connection, dbapi_connection, compiled_ddl): + """Initialize execution context for a DDLElement construct.""" - if not compiled.can_execute: - raise exc.ArgumentError("Not an executable clause: %s" % compiled) + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect - if compiled.statement._execution_options: - self.execution_options = 
compiled.statement._execution_options - if connection._execution_options: - self.execution_options = self.execution_options.union( - connection._execution_options - ) + self.compiled = compiled = compiled_ddl + self.isddl = True - # compiled clauseelement. process bind params, process table defaults, - # track collections used by ResultProxy to target and process results + self.execution_options = compiled.execution_options + if connection._execution_options: + self.execution_options = dict(self.execution_options) + self.execution_options.update(connection._execution_options) - self.processors = dict( - (key, value) for key, value in - ( (compiled.bind_names[bindparam], - bindparam.bind_processor(self.dialect)) - for bindparam in compiled.bind_names ) - if value is not None) + if not dialect.supports_unicode_statements: + self.unicode_statement = util.text_type(compiled) + self.statement = dialect._encoder(self.unicode_statement)[0] + else: + self.statement = self.unicode_statement = util.text_type(compiled) - self.result_map = compiled.result_map + self.cursor = self.create_cursor() + self.compiled_parameters = [] - if not dialect.supports_unicode_statements: - self.unicode_statement = unicode(compiled) - self.statement = self.unicode_statement.encode(self.dialect.encoding) - else: - self.statement = self.unicode_statement = unicode(compiled) + if dialect.positional: + self.parameters = [dialect.execute_sequence_format()] + else: + self.parameters = [{}] - self.isinsert = compiled.isinsert - self.isupdate = compiled.isupdate - self.isdelete = compiled.isdelete + return self - if not parameters: - self.compiled_parameters = [compiled.construct_params()] - else: - self.compiled_parameters = [compiled.construct_params(m, _group_number=grp) for - grp,m in enumerate(parameters)] - - self.executemany = len(parameters) > 1 + @classmethod + def _init_compiled(cls, dialect, connection, dbapi_connection, + compiled, parameters): + """Initialize execution context for a Compiled construct.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + + self.compiled = compiled + + # this should be caught in the engine before + # we get here + assert compiled.can_execute + + self.execution_options = compiled.execution_options.union( + connection._execution_options) + + self.result_column_struct = ( + compiled._result_columns, compiled._ordered_columns, + compiled._textual_ordered_columns) + + self.unicode_statement = util.text_type(compiled) + if not dialect.supports_unicode_statements: + self.statement = self.unicode_statement.encode( + self.dialect.encoding) + else: + self.statement = self.unicode_statement + + self.isinsert = compiled.isinsert + self.isupdate = compiled.isupdate + self.isdelete = compiled.isdelete + self.is_text = compiled.isplaintext + + if not parameters: + self.compiled_parameters = [compiled.construct_params()] + else: + self.compiled_parameters = \ + [compiled.construct_params(m, _group_number=grp) for + grp, m in enumerate(parameters)] - self.cursor = self.create_cursor() - if self.isinsert or self.isupdate: - self.__process_defaults() - self.parameters = self.__convert_compiled_params(self.compiled_parameters) - - elif statement is not None: - # plain text statement - if connection._execution_options: - self.execution_options = self.execution_options.union(connection._execution_options) - self.parameters = self.__encode_param_keys(parameters) self.executemany = len(parameters) > 1 - - if 
isinstance(statement, unicode) and not dialect.supports_unicode_statements: - self.unicode_statement = statement - self.statement = statement.encode(self.dialect.encoding) + + self.cursor = self.create_cursor() + + if self.isinsert or self.isupdate or self.isdelete: + self.is_crud = True + self._is_explicit_returning = bool(compiled.statement._returning) + self._is_implicit_returning = bool( + compiled.returning and not compiled.statement._returning) + + if self.compiled.insert_prefetch or self.compiled.update_prefetch: + if self.executemany: + self._process_executemany_defaults() else: - self.statement = self.unicode_statement = statement - - self.cursor = self.create_cursor() - else: - # no statement. used for standalone ColumnDefault execution. - if connection._execution_options: - self.execution_options = self.execution_options.union(connection._execution_options) - self.cursor = self.create_cursor() - - @util.memoized_property - def is_crud(self): - return self.isinsert or self.isupdate or self.isdelete - - @util.memoized_property - def should_autocommit(self): - autocommit = self.execution_options.get('autocommit', - not self.compiled and - self.statement and - expression.PARSE_AUTOCOMMIT - or False) - - if autocommit is expression.PARSE_AUTOCOMMIT: - return self.should_autocommit_text(self.unicode_statement) - else: - return autocommit - - @util.memoized_property - def _is_explicit_returning(self): - return self.compiled and \ - getattr(self.compiled.statement, '_returning', False) - - @util.memoized_property - def _is_implicit_returning(self): - return self.compiled and \ - bool(self.compiled.returning) and \ - not self.compiled.statement._returning - - @util.memoized_property - def _default_params(self): - if self.dialect.positional: - return self.dialect.execute_sequence_format() - else: - return {} - - def _execute_scalar(self, stmt): - """Execute a string statement on the current cursor, returning a scalar result. - - Used to fire off sequences, default phrases, and "select lastrowid" - types of statements individually - or in the context of a parent INSERT or UPDATE statement. - - """ + self._process_executesingle_defaults() - conn = self._connection - if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements: - stmt = stmt.encode(self.dialect.encoding) - conn._cursor_execute(self.cursor, stmt, self._default_params) - return self.cursor.fetchone()[0] - - @property - def connection(self): - return self._connection._branch() + processors = compiled._bind_processors - def __encode_param_keys(self, params): - """Apply string encoding to the keys of dictionary-based bind parameters. - - This is only used executing textual, non-compiled SQL expressions. - - """ - - if not params: - return [self._default_params] - elif isinstance(params[0], self.dialect.execute_sequence_format): - return params - elif isinstance(params[0], dict): - if self.dialect.supports_unicode_statements: - return params - else: - def proc(d): - return dict((k.encode(self.dialect.encoding), d[k]) for k in d) - return [proc(d) for d in params] or [{}] - else: - return [self.dialect.execute_sequence_format(p) for p in params] - - - def __convert_compiled_params(self, compiled_parameters): - """Convert the dictionary of bind parameter values into a dict or list - to be sent to the DBAPI's execute() or executemany() method. 
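+        # A rough sketch of the conversion performed here, assuming a
+        # positional dialect; all names below are illustrative only:
+        #
+        #     processors = {'x_1': int}
+        #     positiontup = ['x_1', 'x_2']
+        #     compiled_params = {'x_1': '5', 'x_2': 'y'}
+        #     param = [processors[k](compiled_params[k]) if k in processors
+        #              else compiled_params[k] for k in positiontup]
+        #     assert param == [5, 'y']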
- """ - - processors = self.processors + # Convert the dictionary of bind parameter values + # into a dict or list to be sent to the DBAPI's + # execute() or executemany() method. parameters = [] - if self.dialect.positional: - for compiled_params in compiled_parameters: + if dialect.positional: + for compiled_params in self.compiled_parameters: param = [] for key in self.compiled.positiontup: if key in processors: param.append(processors[key](compiled_params[key])) else: param.append(compiled_params[key]) - parameters.append(self.dialect.execute_sequence_format(param)) + parameters.append(dialect.execute_sequence_format(param)) else: - encode = not self.dialect.supports_unicode_statements - for compiled_params in compiled_parameters: - param = {} + encode = not dialect.supports_unicode_statements + for compiled_params in self.compiled_parameters: + if encode: - encoding = self.dialect.encoding - for key in compiled_params: - if key in processors: - param[key.encode(encoding)] = processors[key](compiled_params[key]) - else: - param[key.encode(encoding)] = compiled_params[key] + param = dict( + ( + dialect._encoder(key)[0], + processors[key](compiled_params[key]) + if key in processors + else compiled_params[key] + ) + for key in compiled_params + ) else: - for key in compiled_params: - if key in processors: - param[key] = processors[key](compiled_params[key]) - else: - param[key] = compiled_params[key] + param = dict( + ( + key, + processors[key](compiled_params[key]) + if key in processors + else compiled_params[key] + ) + for key in compiled_params + ) + parameters.append(param) - return self.dialect.execute_sequence_format(parameters) + self.parameters = dialect.execute_sequence_format(parameters) + + return self + + @classmethod + def _init_statement(cls, dialect, connection, dbapi_connection, + statement, parameters): + """Initialize execution context for a string SQL statement.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + self.is_text = True + + # plain text statement + self.execution_options = connection._execution_options + + if not parameters: + if self.dialect.positional: + self.parameters = [dialect.execute_sequence_format()] + else: + self.parameters = [{}] + elif isinstance(parameters[0], dialect.execute_sequence_format): + self.parameters = parameters + elif isinstance(parameters[0], dict): + if dialect.supports_unicode_statements: + self.parameters = parameters + else: + self.parameters = [ + dict((dialect._encoder(k)[0], d[k]) for k in d) + for d in parameters + ] or [{}] + else: + self.parameters = [dialect.execute_sequence_format(p) + for p in parameters] + + self.executemany = len(parameters) > 1 + + if not dialect.supports_unicode_statements and \ + isinstance(statement, util.text_type): + self.unicode_statement = statement + self.statement = dialect._encoder(statement)[0] + else: + self.statement = self.unicode_statement = statement + + self.cursor = self.create_cursor() + return self + + @classmethod + def _init_default(cls, dialect, connection, dbapi_connection): + """Initialize execution context for a ColumnDefault construct.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + self.execution_options = connection._execution_options + self.cursor = self.create_cursor() + return self + + @util.memoized_property + def engine(self): + return self.root_connection.engine + + 
@util.memoized_property
+    def postfetch_cols(self):
+        return self.compiled.postfetch
+
+    @util.memoized_property
+    def prefetch_cols(self):
+        if self.isinsert:
+            return self.compiled.insert_prefetch
+        elif self.isupdate:
+            return self.compiled.update_prefetch
+        else:
+            return ()
+
+    @util.memoized_property
+    def returning_cols(self):
+        return self.compiled.returning
+
+    @util.memoized_property
+    def no_parameters(self):
+        return self.execution_options.get("no_parameters", False)
+
+    @util.memoized_property
+    def should_autocommit(self):
+        autocommit = self.execution_options.get('autocommit',
+                                                not self.compiled and
+                                                self.statement and
+                                                expression.PARSE_AUTOCOMMIT
+                                                or False)
+
+        if autocommit is expression.PARSE_AUTOCOMMIT:
+            return self.should_autocommit_text(self.unicode_statement)
+        else:
+            return autocommit
+
+    def _execute_scalar(self, stmt, type_):
+        """Execute a string statement on the current cursor, returning a
+        scalar result.
+
+        Used to fire off sequences, default phrases, and "select lastrowid"
+        types of statements individually or in the context of a parent INSERT
+        or UPDATE statement.
+
+        """
+
+        conn = self.root_connection
+        if isinstance(stmt, util.text_type) and \
+                not self.dialect.supports_unicode_statements:
+            stmt = self.dialect._encoder(stmt)[0]
+
+        if self.dialect.positional:
+            default_params = self.dialect.execute_sequence_format()
+        else:
+            default_params = {}
+
+        conn._cursor_execute(self.cursor, stmt, default_params, context=self)
+        r = self.cursor.fetchone()[0]
+        if type_ is not None:
+            # apply type post processors to the result
+            proc = type_._cached_result_processor(
+                self.dialect,
+                self.cursor.description[0][1]
+            )
+            if proc:
+                return proc(r)
+        return r
+
+    @property
+    def connection(self):
+        return self.root_connection._branch()

     def should_autocommit_text(self, statement):
         return AUTOCOMMIT_REGEXP.match(statement)

+    def _use_server_side_cursor(self):
+        if not self.dialect.supports_server_side_cursors:
+            return False
+
+        if self.dialect.server_side_cursors:
+            use_server_side = \
+                self.execution_options.get('stream_results', True) and (
+                    (self.compiled and isinstance(self.compiled.statement,
+                                                  expression.Selectable)
+                     or
+                     (
+                        (not self.compiled or
+                         isinstance(self.compiled.statement,
+                                    expression.TextClause))
+                        and self.statement and SERVER_SIDE_CURSOR_RE.match(
+                            self.statement))
+                     )
+                )
+        else:
+            use_server_side = \
+                self.execution_options.get('stream_results', False)
+
+        return use_server_side
+
     def create_cursor(self):
-        return self._connection.connection.cursor()
+        if self._use_server_side_cursor():
+            self._is_server_side = True
+            return self.create_server_side_cursor()
+        else:
+            self._is_server_side = False
+            return self._dbapi_connection.cursor()
+
+    def create_server_side_cursor(self):
+        raise NotImplementedError()

     def pre_exec(self):
         pass

     def post_exec(self):
         pass
-
+
+    def get_result_processor(self, type_, colname, coltype):
+        """Return a 'result processor' for a given type as present in
+        cursor.description.
+
+        This has a default implementation that dialects can override
+        for context-sensitive result type handling.
+
+        """
+        return type_._cached_result_processor(self.dialect, coltype)
+
     def get_lastrowid(self):
         """return self.cursor.lastrowid, or equivalent, after an INSERT.
-
+
         This may involve calling special cursor functions,
         issuing a new SELECT on the cursor (or a new one),
         or returning a stored value that was
         calculated within post_exec().
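+
+        As an illustrative sketch only, a dialect could compute the value
+        during post_exec() and store it, using a hypothetical helper::
+
+            def post_exec(self):
+                if self.isinsert:
+                    self._lastrowid = self._fetch_identity_value()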
- + This function will only be called for dialects which support "implicit" primary key generation, keep preexecute_autoincrement_sequences set to False, and when no explicit id value was bound to the statement. - - The function is called once, directly after + + The function is called once, directly after post_exec() and before the transaction is committed or ResultProxy is generated. If the post_exec() method assigns a value to `self._lastrowid`, the value is used in place of calling get_lastrowid(). - + Note that this method is *not* equivalent to the ``lastrowid`` method on ``ResultProxy``, which is a direct proxy to the DBAPI ``lastrowid`` accessor in all cases. - + """ - return self.cursor.lastrowid def handle_dbapi_exception(self, e): pass def get_result_proxy(self): - return base.ResultProxy(self) - + if self._is_server_side: + return result.BufferedRowResultProxy(self) + else: + return result.ResultProxy(self) + @property def rowcount(self): return self.cursor.rowcount @@ -549,152 +885,235 @@ class DefaultExecutionContext(base.ExecutionContext): def supports_sane_multi_rowcount(self): return self.dialect.supports_sane_multi_rowcount - - def post_insert(self): - if self.dialect.postfetch_lastrowid and \ - (not len(self._inserted_primary_key) or \ - None in self._inserted_primary_key): - - table = self.compiled.statement.table - lastrowid = self.get_lastrowid() - self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v - for c, v in zip(table.primary_key, self._inserted_primary_key) - ] - - def _fetch_implicit_returning(self, resultproxy): - table = self.compiled.statement.table - row = resultproxy.fetchone() - self._inserted_primary_key = [v is not None and v or row[c] - for c, v in zip(table.primary_key, self._inserted_primary_key) + def _setup_crud_result_proxy(self): + if self.isinsert and \ + not self.executemany: + if not self._is_implicit_returning and \ + not self.compiled.inline and \ + self.dialect.postfetch_lastrowid: + + self._setup_ins_pk_from_lastrowid() + + elif not self._is_implicit_returning: + self._setup_ins_pk_from_empty() + + result = self.get_result_proxy() + + if self.isinsert: + if self._is_implicit_returning: + row = result.fetchone() + self.returned_defaults = row + self._setup_ins_pk_from_implicit_returning(row) + result._soft_close() + result._metadata = None + elif not self._is_explicit_returning: + result._soft_close() + result._metadata = None + elif self.isupdate and self._is_implicit_returning: + row = result.fetchone() + self.returned_defaults = row + result._soft_close() + result._metadata = None + + elif result._metadata is None: + # no results, get rowcount + # (which requires open cursor on some drivers + # such as kintersbasdb, mxodbc) + result.rowcount + result._soft_close() + return result + + def _setup_ins_pk_from_lastrowid(self): + key_getter = self.compiled._key_getters_for_crud_column[2] + table = self.compiled.statement.table + compiled_params = self.compiled_parameters[0] + + lastrowid = self.get_lastrowid() + if lastrowid is not None: + autoinc_col = table._autoincrement_column + if autoinc_col is not None: + # apply type post processors to the lastrowid + proc = autoinc_col.type._cached_result_processor( + self.dialect, None) + if proc is not None: + lastrowid = proc(lastrowid) + self.inserted_primary_key = [ + lastrowid if c is autoinc_col else + compiled_params.get(key_getter(c), None) + for c in table.primary_key + ] + else: + # don't have a usable lastrowid, so + # do the same as _setup_ins_pk_from_empty 
+            self.inserted_primary_key = [
+                compiled_params.get(key_getter(c), None)
+                for c in table.primary_key
+            ]
+
+    def _setup_ins_pk_from_empty(self):
+        key_getter = self.compiled._key_getters_for_crud_column[2]
+        table = self.compiled.statement.table
+        compiled_params = self.compiled_parameters[0]
+        self.inserted_primary_key = [
+            compiled_params.get(key_getter(c), None)
+            for c in table.primary_key
         ]

-    def last_inserted_params(self):
-        return self._last_inserted_params
+    def _setup_ins_pk_from_implicit_returning(self, row):
+        if row is None:
+            self.inserted_primary_key = None
+            return

-    def last_updated_params(self):
-        return self._last_updated_params
+        key_getter = self.compiled._key_getters_for_crud_column[2]
+        table = self.compiled.statement.table
+        compiled_params = self.compiled_parameters[0]
+        self.inserted_primary_key = [
+            row[col] if value is None else value
+            for col, value in [
+                (col, compiled_params.get(key_getter(col), None))
+                for col in table.primary_key
+            ]
+        ]

     def lastrow_has_defaults(self):
-        return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
+        return (self.isinsert or self.isupdate) and \
+            bool(self.compiled.postfetch)

     def set_input_sizes(self, translate=None, exclude_types=None):
         """Given a cursor and ClauseParameters, call the appropriate
         style of ``setinputsizes()`` on the cursor, using DB-API types
         from the bind parameter's ``TypeEngine`` objects.
+
+        This method is only called by those dialects which require it,
+        currently cx_oracle.
+
         """

         if not hasattr(self.compiled, 'bind_names'):
             return

         types = dict(
-                (self.compiled.bind_names[bindparam], bindparam.type)
-                 for bindparam in self.compiled.bind_names)
+            (self.compiled.bind_names[bindparam], bindparam.type)
+            for bindparam in self.compiled.bind_names)

         if self.dialect.positional:
             inputsizes = []
             for key in self.compiled.positiontup:
                 typeengine = types[key]
-                dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
-                if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
+                dbtype = typeengine.dialect_impl(self.dialect).\
+                    get_dbapi_type(self.dialect.dbapi)
+                if dbtype is not None and \
+                        (not exclude_types or dbtype not in exclude_types):
                     inputsizes.append(dbtype)
             try:
                 self.cursor.setinputsizes(*inputsizes)
-            except Exception, e:
-                self._connection._handle_dbapi_exception(e, None, None, None, self)
-                raise
+            except BaseException as e:
+                self.root_connection._handle_dbapi_exception(
+                    e, None, None, None, self)
         else:
             inputsizes = {}
             for key in self.compiled.bind_names.values():
                 typeengine = types[key]
-                dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
-                if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
+                dbtype = typeengine.dialect_impl(self.dialect).\
+                    get_dbapi_type(self.dialect.dbapi)
+                if dbtype is not None and \
+                        (not exclude_types or dbtype not in exclude_types):
                     if translate:
                         key = translate.get(key, key)
-                    inputsizes[key.encode(self.dialect.encoding)] = dbtype
+                    if not self.dialect.supports_unicode_binds:
+                        key = self.dialect._encoder(key)[0]
+                    inputsizes[key] = dbtype
             try:
                 self.cursor.setinputsizes(**inputsizes)
-            except Exception, e:
-                self._connection._handle_dbapi_exception(e, None, None, None, self)
-                raise
+            except BaseException as e:
+                self.root_connection._handle_dbapi_exception(
+                    e, None, None, None, self)

-    def _exec_default(self, default):
+    def _exec_default(self, default, type_):
         if default.is_sequence:
-            return
self.fire_sequence(default, type_) elif default.is_callable: return default.arg(self) elif default.is_clause_element: - # TODO: expensive branching here should be + # TODO: expensive branching here should be # pulled into _exec_scalar() - conn = self.connection + conn = self.connection c = expression.select([default.arg]).compile(bind=conn) return conn._execute_compiled(c, (), {}).scalar() else: return default.arg - + def get_insert_default(self, column): if column.default is None: return None else: - return self._exec_default(column.default) + return self._exec_default(column.default, column.type) def get_update_default(self, column): if column.onupdate is None: return None else: - return self._exec_default(column.onupdate) + return self._exec_default(column.onupdate, column.type) - def __process_defaults(self): - """Generate default values for compiled insert/update statements, - and generate inserted_primary_key collection. - """ + def _process_executemany_defaults(self): + key_getter = self.compiled._key_getters_for_crud_column[2] - if self.executemany: - if len(self.compiled.prefetch): - scalar_defaults = {} - - # pre-determine scalar Python-side defaults - # to avoid many calls of get_insert_default()/get_update_default() - for c in self.compiled.prefetch: - if self.isinsert and c.default and c.default.is_scalar: - scalar_defaults[c] = c.default.arg - elif self.isupdate and c.onupdate and c.onupdate.is_scalar: - scalar_defaults[c] = c.onupdate.arg - - for param in self.compiled_parameters: - self.current_parameters = param - for c in self.compiled.prefetch: - if c in scalar_defaults: - val = scalar_defaults[c] - elif self.isinsert: - val = self.get_insert_default(c) - else: - val = self.get_update_default(c) - if val is not None: - param[c.key] = val - del self.current_parameters + scalar_defaults = {} - else: - self.current_parameters = compiled_parameters = self.compiled_parameters[0] + insert_prefetch = self.compiled.insert_prefetch + update_prefetch = self.compiled.update_prefetch - for c in self.compiled.prefetch: - if self.isinsert: + # pre-determine scalar Python-side defaults + # to avoid many calls of get_insert_default()/ + # get_update_default() + for c in insert_prefetch: + if c.default and c.default.is_scalar: + scalar_defaults[c] = c.default.arg + for c in update_prefetch: + if c.onupdate and c.onupdate.is_scalar: + scalar_defaults[c] = c.onupdate.arg + + for param in self.compiled_parameters: + self.current_parameters = param + for c in insert_prefetch: + if c in scalar_defaults: + val = scalar_defaults[c] + else: val = self.get_insert_default(c) + if val is not None: + param[key_getter(c)] = val + for c in update_prefetch: + if c in scalar_defaults: + val = scalar_defaults[c] else: val = self.get_update_default(c) - if val is not None: - compiled_parameters[c.key] = val - del self.current_parameters + param[key_getter(c)] = val - if self.isinsert: - self._inserted_primary_key = [compiled_parameters.get(c.key, None) - for c in self.compiled.statement.table.primary_key] - self._last_inserted_params = compiled_parameters + del self.current_parameters + + def _process_executesingle_defaults(self): + key_getter = self.compiled._key_getters_for_crud_column[2] + self.current_parameters = compiled_parameters = \ + self.compiled_parameters[0] + + for c in self.compiled.insert_prefetch: + if c.default and \ + not c.default.is_sequence and c.default.is_scalar: + val = c.default.arg else: - self._last_updated_params = compiled_parameters + val = self.get_insert_default(c) + + if 
val is not None: + compiled_parameters[key_getter(c)] = val + + for c in self.compiled.update_prefetch: + val = self.get_update_default(c) + + if val is not None: + compiled_parameters[key_getter(c)] = val + del self.current_parameters + - self.postfetch_cols = self.compiled.postfetch - self.prefetch_cols = self.compiled.prefetch - DefaultDialect.execution_ctx_cls = DefaultExecutionContext diff --git a/sqlalchemy/engine/reflection.py b/sqlalchemy/engine/reflection.py index 57f2205..dfa81f4 100644 --- a/sqlalchemy/engine/reflection.py +++ b/sqlalchemy/engine/reflection.py @@ -1,3 +1,10 @@ +# engine/reflection.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + """Provides an abstraction for obtaining database schema information. Usage Notes: @@ -18,11 +25,14 @@ methods such as get_table_names, get_columns, etc. 'name' attribute.. """ -import sqlalchemy -from sqlalchemy import exc, sql -from sqlalchemy import util -from sqlalchemy.types import TypeEngine -from sqlalchemy import schema as sa_schema +from .. import exc, sql +from ..sql import schema as sa_schema +from .. import util +from ..sql.type_api import TypeEngine +from ..util import deprecated +from ..util import topological +from .. import inspection +from .base import Connectable @util.decorator @@ -31,10 +41,14 @@ def cache(fn, self, con, *args, **kw): if info_cache is None: return fn(self, con, *args, **kw) key = ( - fn.__name__, - tuple(a for a in args if isinstance(a, basestring)), - tuple((k, v) for k, v in kw.iteritems() if isinstance(v, (basestring, int, float))) - ) + fn.__name__, + tuple(a for a in args if isinstance(a, util.string_types)), + tuple((k, v) for k, v in kw.items() if + isinstance(v, + util.string_types + util.int_types + (float, ) + ) + ) + ) ret = info_cache.get(key) if ret is None: ret = fn(self, con, *args, **kw) @@ -45,33 +59,94 @@ def cache(fn, self, con, *args, **kw): class Inspector(object): """Performs database schema inspection. - The Inspector acts as a proxy to the dialects' reflection methods and - provides higher level functions for accessing database schema information. + The Inspector acts as a proxy to the reflection methods of the + :class:`~sqlalchemy.engine.interfaces.Dialect`, providing a + consistent interface as well as caching support for previously + fetched metadata. + + A :class:`.Inspector` object is usually created via the + :func:`.inspect` function:: + + from sqlalchemy import inspect, create_engine + engine = create_engine('...') + insp = inspect(engine) + + The inspection method above is equivalent to using the + :meth:`.Inspector.from_engine` method, i.e.:: + + engine = create_engine('...') + insp = Inspector.from_engine(engine) + + Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` may opt + to return an :class:`.Inspector` subclass that provides additional + methods specific to the dialect's target database. + """ - def __init__(self, conn): - """Initialize the instance. + def __init__(self, bind): + """Initialize a new :class:`.Inspector`. + + :param bind: a :class:`~sqlalchemy.engine.Connectable`, + which is typically an instance of + :class:`~sqlalchemy.engine.Engine` or + :class:`~sqlalchemy.engine.Connection`. 
+
+        For a dialect-specific instance of :class:`.Inspector`, see
+        :meth:`.Inspector.from_engine`

-        :param conn: a :class:`~sqlalchemy.engine.base.Connectable`
         """
+        # this might not be a connection, it could be an engine.
+        self.bind = bind

-        self.conn = conn
         # set the engine
-        if hasattr(conn, 'engine'):
-            self.engine = conn.engine
+        if hasattr(bind, 'engine'):
+            self.engine = bind.engine
         else:
-            self.engine = conn
+            self.engine = bind
+
+        if self.engine is bind:
+            # if engine, ensure initialized
+            bind.connect().close()
+
         self.dialect = self.engine.dialect
         self.info_cache = {}

     @classmethod
-    def from_engine(cls, engine):
-        if hasattr(engine.dialect, 'inspector'):
-            return engine.dialect.inspector(engine)
-        return Inspector(engine)
+    def from_engine(cls, bind):
+        """Construct a new dialect-specific Inspector object from the given
+        engine or connection.
+
+        :param bind: a :class:`~sqlalchemy.engine.Connectable`,
+          which is typically an instance of
+          :class:`~sqlalchemy.engine.Engine` or
+          :class:`~sqlalchemy.engine.Connection`.
+
+        This method differs from a direct constructor call of
+        :class:`.Inspector` in that the
+        :class:`~sqlalchemy.engine.interfaces.Dialect` is given a chance to
+        provide a dialect-specific :class:`.Inspector` instance, which may
+        provide additional methods.
+
+        See the example at :class:`.Inspector`.
+
+        """
+        if hasattr(bind.dialect, 'inspector'):
+            return bind.dialect.inspector(bind)
+        return Inspector(bind)
+
+    @inspection._inspects(Connectable)
+    def _insp(bind):
+        return Inspector.from_engine(bind)

     @property
     def default_schema_name(self):
+        """Return the default schema name presented by the dialect
+        for the current engine's database user.
+
+        E.g. this is typically ``public`` for PostgreSQL and ``dbo``
+        for SQL Server.
+
+        """
         return self.dialect.default_schema_name

     def get_schema_names(self):
@@ -79,70 +154,185 @@ class Inspector(object):

         """
         if hasattr(self.dialect, 'get_schema_names'):
-            return self.dialect.get_schema_names(self.conn,
-                    info_cache=self.info_cache)
+            return self.dialect.get_schema_names(self.bind,
+                                                 info_cache=self.info_cache)
         return []

     def get_table_names(self, schema=None, order_by=None):
-        """Return all table names in `schema`.
+        """Return all table names referred to within a particular schema.
+
+        The names are expected to be real tables only, not views.
+        Views are instead returned using the :meth:`.Inspector.get_view_names`
+        method.
+
+
+        :param schema: Schema name. If ``schema`` is left at ``None``, the
+         database's default schema is
+         used, else the named schema is searched. If the database does not
+         support named schemas, behavior is undefined if ``schema`` is not
+         passed as ``None``. For special quoting, use :class:`.quoted_name`.

-        :param schema: Optional, retrieve names from a non-default schema.
         :param order_by: Optional, may be the string "foreign_key" to sort
-            the result on foreign key dependencies.
+         the result on foreign key dependencies. Does not automatically
+         resolve cycles, and will raise :class:`.CircularDependencyError`
+         if cycles exist.
+
+         .. deprecated:: 1.0.0 - see
+            :meth:`.Inspector.get_sorted_table_and_fkc_names` for a version
+            of this which resolves foreign key cycles between tables
+            automatically.
+
+         .. versionchanged:: 0.8 the "foreign_key" sorting sorts tables
+            in order of dependee to dependent; that is, in creation
+            order, rather than in drop order. This is to maintain
+            consistency with similar features such as
+            :attr:`.MetaData.sorted_tables` and :func:`.util.sort_tables`.
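+
+         A rough usage sketch (table names illustrative only): with tables
+         ``user`` and ``address``, where ``address`` has a foreign key to
+         ``user``::
+
+             insp.get_table_names(order_by='foreign_key')
+             # ['user', 'address'] - dependee sorts before dependent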
+
+        .. seealso::

+            :meth:`.Inspector.get_sorted_table_and_fkc_names`
+
+            :attr:`.MetaData.sorted_tables`

-        This should probably not return view names or maybe it should return
-        them with an indicator t or v.
         """
         if hasattr(self.dialect, 'get_table_names'):
-            tnames = self.dialect.get_table_names(self.conn,
-            schema,
-                                                 info_cache=self.info_cache)
+            tnames = self.dialect.get_table_names(
+                self.bind, schema, info_cache=self.info_cache)
         else:
             tnames = self.engine.table_names(schema)
         if order_by == 'foreign_key':
-            ordered_tnames = tnames[:]
-            # Order based on foreign key dependencies.
+            tuples = []
             for tname in tnames:
-                table_pos = tnames.index(tname)
-                fkeys = self.get_foreign_keys(tname, schema)
-                for fkey in fkeys:
-                    rtable = fkey['referred_table']
-                    if rtable in ordered_tnames:
-                        ref_pos = ordered_tnames.index(rtable)
-                        # Make sure it's lower in the list than anything it
-                        # references.
-                        if table_pos > ref_pos:
-                            ordered_tnames.pop(table_pos) # rtable moves up 1
-                            # insert just below rtable
-                            ordered_tnames.index(ref_pos, tname)
-            tnames = ordered_tnames
+                for fkey in self.get_foreign_keys(tname, schema):
+                    if tname != fkey['referred_table']:
+                        tuples.append((fkey['referred_table'], tname))
+            tnames = list(topological.sort(tuples, tnames))
         return tnames

+    def get_sorted_table_and_fkc_names(self, schema=None):
+        """Return dependency-sorted table and foreign key constraint names
+        referred to within a particular schema.
+
+        This will yield 2-tuples of
+        ``(tablename, [(tname, fkname), (tname, fkname), ...])``
+        consisting of table names in CREATE order grouped with the foreign key
+        constraint names that are not detected as belonging to a cycle.
+        The final element
+        will be ``(None, [(tname, fkname), (tname, fkname), ..])``
+        which will consist of remaining
+        foreign key constraint names that would require a separate CREATE
+        step after-the-fact, based on dependencies between tables.
+
+        .. versionadded:: 1.0.0
+
+        .. seealso::
+
+            :meth:`.Inspector.get_table_names`
+
+            :func:`.sort_tables_and_constraints` - similar method which works
+            with an already-given :class:`.MetaData`.
+
+        """
+        if hasattr(self.dialect, 'get_table_names'):
+            tnames = self.dialect.get_table_names(
+                self.bind, schema, info_cache=self.info_cache)
+        else:
+            tnames = self.engine.table_names(schema)
+
+        tuples = set()
+        remaining_fkcs = set()
+
+        fknames_for_table = {}
+        for tname in tnames:
+            fkeys = self.get_foreign_keys(tname, schema)
+            fknames_for_table[tname] = set(
+                [fk['name'] for fk in fkeys]
+            )
+            for fkey in fkeys:
+                if tname != fkey['referred_table']:
+                    tuples.add((fkey['referred_table'], tname))
+        try:
+            candidate_sort = list(topological.sort(tuples, tnames))
+        except exc.CircularDependencyError as err:
+            for edge in err.edges:
+                tuples.remove(edge)
+                remaining_fkcs.update(
+                    (edge[1], fkc)
+                    for fkc in fknames_for_table[edge[1]]
+                )
+
+            candidate_sort = list(topological.sort(tuples, tnames))
+        return [
+            (tname, fknames_for_table[tname].difference(remaining_fkcs))
+            for tname in candidate_sort
+        ] + [(None, list(remaining_fkcs))]
+
+    def get_temp_table_names(self):
+        """return a list of temporary table names for the current bind.
+
+        This method is unsupported by most dialects; currently
+        only SQLite implements it.
+
+        .. versionadded:: 1.0.0
+
+        """
+        return self.dialect.get_temp_table_names(
+            self.bind, info_cache=self.info_cache)
+
+    def get_temp_view_names(self):
+        """return a list of temporary view names for the current bind.
+ + This method is unsupported by most dialects; currently + only SQLite implements it. + + .. versionadded:: 1.0.0 + + """ + return self.dialect.get_temp_view_names( + self.bind, info_cache=self.info_cache) + def get_table_options(self, table_name, schema=None, **kw): + """Return a dictionary of options specified when the table of the + given name was created. + + This currently includes some options that apply to MySQL tables. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ if hasattr(self.dialect, 'get_table_options'): - return self.dialect.get_table_options(self.conn, table_name, schema, - info_cache=self.info_cache, - **kw) + return self.dialect.get_table_options( + self.bind, table_name, schema, + info_cache=self.info_cache, **kw) return {} def get_view_names(self, schema=None): """Return all view names in `schema`. :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + """ - return self.dialect.get_view_names(self.conn, schema, - info_cache=self.info_cache) + return self.dialect.get_view_names(self.bind, schema, + info_cache=self.info_cache) def get_view_definition(self, view_name, schema=None): """Return definition for `view_name`. :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + """ return self.dialect.get_view_definition( - self.conn, view_name, schema, info_cache=self.info_cache) + self.bind, view_name, schema, info_cache=self.info_cache) def get_columns(self, table_name, schema=None, **kw): """Return information about columns in `table_name`. @@ -150,23 +340,31 @@ class Inspector(object): Given a string `table_name` and an optional string `schema`, return column information as a list of dicts with these keys: - name - the column's name + * ``name`` - the column's name - type + * ``type`` - the type of this column; an instance of :class:`~sqlalchemy.types.TypeEngine` - nullable - boolean + * ``nullable`` - boolean flag if the column is NULL or NOT NULL - default - the column's default value + * ``default`` - the column's server default value - this is returned + as a string SQL expression. + + * ``attrs`` - dict containing optional column attributes + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :return: list of dictionaries, each representing the definition of + a database column. - attrs - dict containing optional column attributes """ - col_defs = self.dialect.get_columns(self.conn, table_name, schema, + col_defs = self.dialect.get_columns(self.bind, table_name, schema, info_cache=self.info_cache, **kw) for col_def in col_defs: @@ -176,6 +374,8 @@ class Inspector(object): col_def['type'] = coltype() return col_defs + @deprecated('0.7', 'Call to deprecated method get_primary_keys.' + ' Use get_pk_constraint instead.') def get_primary_keys(self, table_name, schema=None, **kw): """Return information about primary keys in `table_name`. @@ -183,12 +383,34 @@ class Inspector(object): primary key information as a list of column names. 
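+
+        As an illustrative sketch only (table and column names are
+        hypothetical), a table with a composite primary key would yield::
+
+            insp.get_primary_keys('user_account')
+            # ['id', 'version']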
""" - pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema, + return self.dialect.get_pk_constraint(self.bind, table_name, schema, + info_cache=self.info_cache, + **kw)['constrained_columns'] + + def get_pk_constraint(self, table_name, schema=None, **kw): + """Return information about primary key constraint on `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + primary key information as a dictionary with these keys: + + constrained_columns + a list of column names that make up the primary key + + name + optional name of the primary key constraint. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ + return self.dialect.get_pk_constraint(self.bind, table_name, schema, info_cache=self.info_cache, **kw) - return pkeys - def get_foreign_keys(self, table_name, schema=None, **kw): """Return information about foreign_keys in `table_name`. @@ -208,15 +430,21 @@ class Inspector(object): a list of column names in the referred table that correspond to constrained_columns - \**kw - other options passed to the dialect's get_foreign_keys() method. + name + optional name of the foreign key constraint. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. """ - fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema, - info_cache=self.info_cache, - **kw) - return fk_defs + return self.dialect.get_foreign_keys(self.bind, table_name, schema, + info_cache=self.info_cache, + **kw) def get_indexes(self, table_name, schema=None, **kw): """Return information about indexes in `table_name`. @@ -232,104 +460,261 @@ class Inspector(object): unique boolean - - \**kw - other options passed to the dialect's get_indexes() method. + + dialect_options + dict of dialect-specific index options. May not be present + for all dialects. + + .. versionadded:: 1.0.0 + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + """ - indexes = self.dialect.get_indexes(self.conn, table_name, - schema, - info_cache=self.info_cache, **kw) - return indexes + return self.dialect.get_indexes(self.bind, table_name, + schema, + info_cache=self.info_cache, **kw) - def reflecttable(self, table, include_columns): + def get_unique_constraints(self, table_name, schema=None, **kw): + """Return information about unique constraints in `table_name`. - dialect = self.conn.dialect + Given a string `table_name` and an optional string `schema`, return + unique constraint information as a list of dicts with these keys: - # MySQL dialect does this. Applicable with other dialects? - if hasattr(dialect, '_connection_charset') \ - and hasattr(dialect, '_adjust_casing'): - charset = dialect._connection_charset - dialect._adjust_casing(table) + name + the unique constraint's name - # table attributes we might need. 
- reflection_options = dict( - (k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs) + column_names + list of column names in order + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + .. versionadded:: 0.8.4 + + """ + + return self.dialect.get_unique_constraints( + self.bind, table_name, schema, info_cache=self.info_cache, **kw) + + def get_check_constraints(self, table_name, schema=None, **kw): + """Return information about check constraints in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + check constraint information as a list of dicts with these keys: + + name + the check constraint's name + + sqltext + the check constraint's SQL expression + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + .. versionadded:: 1.1.0 + + """ + + return self.dialect.get_check_constraints( + self.bind, table_name, schema, info_cache=self.info_cache, **kw) + + def reflecttable(self, table, include_columns, exclude_columns=(), + _extend_on=None): + """Given a Table object, load its internal constructs based on + introspection. + + This is the underlying method used by most dialects to produce + table reflection. Direct usage is like:: + + from sqlalchemy import create_engine, MetaData, Table + from sqlalchemy.engine import reflection + + engine = create_engine('...') + meta = MetaData() + user_table = Table('user', meta) + insp = Inspector.from_engine(engine) + insp.reflecttable(user_table, None) + + :param table: a :class:`~sqlalchemy.schema.Table` instance. + :param include_columns: a list of string column names to include + in the reflection process. If ``None``, all columns are reflected. + + """ + + if _extend_on is not None: + if table in _extend_on: + return + else: + _extend_on.add(table) + + dialect = self.bind.dialect + + schema = self.bind.schema_for_object(table) - schema = table.schema table_name = table.name - # apply table options - tbl_opts = self.get_table_options(table_name, schema, **table.kwargs) + # get table-level arguments that are specifically + # intended for reflection, e.g. oracle_resolve_synonyms. + # these are unconditionally passed to related Table + # objects + reflection_options = dict( + (k, table.dialect_kwargs.get(k)) + for k in dialect.reflection_options + if k in table.dialect_kwargs + ) + + # reflect table options, like mysql_engine + tbl_opts = self.get_table_options( + table_name, schema, **table.dialect_kwargs) if tbl_opts: - table.kwargs.update(tbl_opts) + # add additional kwargs to the Table if the dialect + # returned them + table._validate_dialect_kwargs(tbl_opts) - # table.kwargs will need to be passed to each reflection method. Make - # sure keywords are strings. 
- tblkw = table.kwargs.copy() - for (k, v) in tblkw.items(): - del tblkw[k] - tblkw[str(k)] = v + if util.py2k: + if isinstance(schema, str): + schema = schema.decode(dialect.encoding) + if isinstance(table_name, str): + table_name = table_name.decode(dialect.encoding) - # Py2K - if isinstance(schema, str): - schema = schema.decode(dialect.encoding) - if isinstance(table_name, str): - table_name = table_name.decode(dialect.encoding) - # end Py2K - - # columns found_table = False - for col_d in self.get_columns(table_name, schema, **tblkw): - found_table = True - name = col_d['name'] - if include_columns and name not in include_columns: - continue + cols_by_orig_name = {} - coltype = col_d['type'] - col_kw = { - 'nullable':col_d['nullable'], - } - if 'autoincrement' in col_d: - col_kw['autoincrement'] = col_d['autoincrement'] - if 'quote' in col_d: - col_kw['quote'] = col_d['quote'] - - colargs = [] - if col_d.get('default') is not None: - # the "default" value is assumed to be a literal SQL expression, - # so is wrapped in text() so that no quoting occurs on re-issuance. - colargs.append(sa_schema.DefaultClause(sql.text(col_d['default']))) - - if 'sequence' in col_d: - # TODO: mssql, maxdb and sybase are using this. - seq = col_d['sequence'] - sequence = sa_schema.Sequence(seq['name'], 1, 1) - if 'start' in seq: - sequence.start = seq['start'] - if 'increment' in seq: - sequence.increment = seq['increment'] - colargs.append(sequence) - - col = sa_schema.Column(name, coltype, *colargs, **col_kw) - table.append_column(col) + for col_d in self.get_columns( + table_name, schema, **table.dialect_kwargs): + found_table = True + + self._reflect_column( + table, col_d, include_columns, + exclude_columns, cols_by_orig_name) if not found_table: raise exc.NoSuchTableError(table.name) - # Primary keys - primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[ - table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw) - if pk in table.c - ]) + self._reflect_pk( + table_name, schema, table, cols_by_orig_name, exclude_columns) - table.append_constraint(primary_key_constraint) + self._reflect_fk( + table_name, schema, table, cols_by_orig_name, + exclude_columns, _extend_on, reflection_options) - # Foreign keys - fkeys = self.get_foreign_keys(table_name, schema, **tblkw) + self._reflect_indexes( + table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options) + + self._reflect_unique_constraints( + table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options) + + self._reflect_check_constraints( + table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options) + + def _reflect_column( + self, table, col_d, include_columns, + exclude_columns, cols_by_orig_name): + + orig_name = col_d['name'] + + table.dispatch.column_reflect(self, table, col_d) + + # fetch name again as column_reflect is allowed to + # change it + name = col_d['name'] + if (include_columns and name not in include_columns) \ + or (exclude_columns and name in exclude_columns): + return + + coltype = col_d['type'] + + col_kw = dict( + (k, col_d[k]) + for k in ['nullable', 'autoincrement', 'quote', 'info', 'key'] + if k in col_d + ) + + colargs = [] + if col_d.get('default') is not None: + default = col_d['default'] + if isinstance(default, sql.elements.TextClause): + default = sa_schema.DefaultClause(default, _reflected=True) + elif not isinstance(default, sa_schema.FetchedValue): + default = 
sa_schema.DefaultClause( + sql.text(col_d['default']), _reflected=True) + + colargs.append(default) + + if 'sequence' in col_d: + self._reflect_col_sequence(col_d, colargs) + + cols_by_orig_name[orig_name] = col = \ + sa_schema.Column(name, coltype, *colargs, **col_kw) + + if col.key in table.primary_key: + col.primary_key = True + table.append_column(col) + + def _reflect_col_sequence(self, col_d, colargs): + if 'sequence' in col_d: + # TODO: mssql and sybase are using this. + seq = col_d['sequence'] + sequence = sa_schema.Sequence(seq['name'], 1, 1) + if 'start' in seq: + sequence.start = seq['start'] + if 'increment' in seq: + sequence.increment = seq['increment'] + colargs.append(sequence) + + def _reflect_pk( + self, table_name, schema, table, + cols_by_orig_name, exclude_columns): + pk_cons = self.get_pk_constraint( + table_name, schema, **table.dialect_kwargs) + if pk_cons: + pk_cols = [ + cols_by_orig_name[pk] + for pk in pk_cons['constrained_columns'] + if pk in cols_by_orig_name and pk not in exclude_columns + ] + + # update pk constraint name + table.primary_key.name = pk_cons.get('name') + + # tell the PKConstraint to re-initialize + # its column collection + table.primary_key._reload(pk_cols) + + def _reflect_fk( + self, table_name, schema, table, cols_by_orig_name, + exclude_columns, _extend_on, reflection_options): + fkeys = self.get_foreign_keys( + table_name, schema, **table.dialect_kwargs) for fkey_d in fkeys: conname = fkey_d['name'] - constrained_columns = fkey_d['constrained_columns'] + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + constrained_columns = [ + cols_by_orig_name[c].key + if c in cols_by_orig_name else c + for c in fkey_d['constrained_columns'] + ] + if exclude_columns and set(constrained_columns).intersection( + exclude_columns): + continue referred_schema = fkey_d['referred_schema'] referred_table = fkey_d['referred_table'] referred_columns = fkey_d['referred_columns'] @@ -337,7 +722,8 @@ class Inspector(object): if referred_schema is not None: sa_schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, - autoload_with=self.conn, + autoload_with=self.bind, + _extend_on=_extend_on, **reflection_options ) for column in referred_columns: @@ -345,26 +731,113 @@ class Inspector(object): [referred_schema, referred_table, column])) else: sa_schema.Table(referred_table, table.metadata, autoload=True, - autoload_with=self.conn, + autoload_with=self.bind, + schema=sa_schema.BLANK_SCHEMA, + _extend_on=_extend_on, **reflection_options ) for column in referred_columns: refspec.append(".".join([referred_table, column])) + if 'options' in fkey_d: + options = fkey_d['options'] + else: + options = {} table.append_constraint( sa_schema.ForeignKeyConstraint(constrained_columns, refspec, - conname, link_to_name=True)) - # Indexes + conname, link_to_name=True, + **options)) + + def _reflect_indexes( + self, table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options): + # Indexes indexes = self.get_indexes(table_name, schema) for index_d in indexes: name = index_d['name'] columns = index_d['column_names'] unique = index_d['unique'] - flavor = index_d.get('type', 'unknown type') + flavor = index_d.get('type', 'index') + dialect_options = index_d.get('dialect_options', {}) + + duplicates = index_d.get('duplicates_constraint') if include_columns and \ - not set(columns).issubset(include_columns): + not set(columns).issubset(include_columns): 
util.warn( - "Omitting %s KEY for (%s), key covers omitted columns." % + "Omitting %s key for (%s), key covers omitted columns." % (flavor, ', '.join(columns))) continue - sa_schema.Index(name, *[table.columns[c] for c in columns], - **dict(unique=unique)) + if duplicates: + continue + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + idx_cols = [] + for c in columns: + try: + idx_col = cols_by_orig_name[c] \ + if c in cols_by_orig_name else table.c[c] + except KeyError: + util.warn( + "%s key '%s' was not located in " + "columns for table '%s'" % ( + flavor, c, table_name + )) + else: + idx_cols.append(idx_col) + + sa_schema.Index( + name, *idx_cols, + **dict(list(dialect_options.items()) + [('unique', unique)]) + ) + + def _reflect_unique_constraints( + self, table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options): + + # Unique Constraints + try: + constraints = self.get_unique_constraints(table_name, schema) + except NotImplementedError: + # optional dialect feature + return + + for const_d in constraints: + conname = const_d['name'] + columns = const_d['column_names'] + duplicates = const_d.get('duplicates_index') + if include_columns and \ + not set(columns).issubset(include_columns): + util.warn( + "Omitting unique constraint key for (%s), " + "key covers omitted columns." % + ', '.join(columns)) + continue + if duplicates: + continue + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + constrained_cols = [] + for c in columns: + try: + constrained_col = cols_by_orig_name[c] \ + if c in cols_by_orig_name else table.c[c] + except KeyError: + util.warn( + "unique constraint key '%s' was not located in " + "columns for table '%s'" % (c, table_name)) + else: + constrained_cols.append(constrained_col) + table.append_constraint( + sa_schema.UniqueConstraint(*constrained_cols, name=conname)) + + def _reflect_check_constraints( + self, table_name, schema, table, cols_by_orig_name, + include_columns, exclude_columns, reflection_options): + try: + constraints = self.get_check_constraints(table_name, schema) + except NotImplementedError: + # optional dialect feature + return + + for const_d in constraints: + table.append_constraint( + sa_schema.CheckConstraint(**const_d)) diff --git a/sqlalchemy/engine/strategies.py b/sqlalchemy/engine/strategies.py index 7fc39b9..81bb2c5 100644 --- a/sqlalchemy/engine/strategies.py +++ b/sqlalchemy/engine/strategies.py @@ -1,3 +1,10 @@ +# engine/strategies.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + """Strategies for creating new instances of Engine types. These are semi-private implementation classes which provide the @@ -11,18 +18,19 @@ New strategies can be added via new ``EngineStrategy`` classes. from operator import attrgetter from sqlalchemy.engine import base, threadlocal, url -from sqlalchemy import util, exc +from sqlalchemy import util, event from sqlalchemy import pool as poollib +from sqlalchemy.sql import schema strategies = {} class EngineStrategy(object): - """An adaptor that processes input arguements and produces an Engine. + """An adaptor that processes input arguments and produces an Engine. Provides a ``create`` method that receives input arguments and produces an instance of base.Engine or a subclass. 
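+
+    For example, a strategy is selected via the ``strategy`` keyword of
+    :func:`~sqlalchemy.create_engine`; a rough sketch (URL illustrative)::
+
+        from sqlalchemy import create_engine
+
+        engine = create_engine('sqlite://', strategy='threadlocal')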
- + """ def __init__(self): @@ -35,58 +43,75 @@ class EngineStrategy(object): class DefaultEngineStrategy(EngineStrategy): - """Base class for built-in stratgies.""" + """Base class for built-in strategies.""" - pool_threadlocal = False - def create(self, name_or_url, **kwargs): # create url.URL object u = url.make_url(name_or_url) - dialect_cls = u.get_dialect() + plugins = u._instantiate_plugins(kwargs) + + u.query.pop('plugin', None) + + entrypoint = u._get_entrypoint() + dialect_cls = entrypoint.get_dialect_cls(u) + + if kwargs.pop('_coerce_config', False): + def pop_kwarg(key, default=None): + value = kwargs.pop(key, default) + if key in dialect_cls.engine_config_types: + value = dialect_cls.engine_config_types[key](value) + return value + else: + pop_kwarg = kwargs.pop dialect_args = {} # consume dialect arguments from kwargs for k in util.get_cls_kwargs(dialect_cls): if k in kwargs: - dialect_args[k] = kwargs.pop(k) + dialect_args[k] = pop_kwarg(k) dbapi = kwargs.pop('module', None) if dbapi is None: dbapi_args = {} for k in util.get_func_kwargs(dialect_cls.dbapi): if k in kwargs: - dbapi_args[k] = kwargs.pop(k) + dbapi_args[k] = pop_kwarg(k) dbapi = dialect_cls.dbapi(**dbapi_args) dialect_args['dbapi'] = dbapi + for plugin in plugins: + plugin.handle_dialect_kwargs(dialect_cls, dialect_args) + # create dialect dialect = dialect_cls(**dialect_args) # assemble connection arguments (cargs, cparams) = dialect.create_connect_args(u) - cparams.update(kwargs.pop('connect_args', {})) + cparams.update(pop_kwarg('connect_args', {})) + cargs = list(cargs) # allow mutability # look for existing pool or create - pool = kwargs.pop('pool', None) + pool = pop_kwarg('pool', None) if pool is None: - def connect(): - try: - return dialect.connect(*cargs, **cparams) - except Exception, e: - # Py3K - #raise exc.DBAPIError.instance(None, None, e) from e - # Py2K - import sys - raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2] - # end Py2K - - creator = kwargs.pop('creator', connect) + def connect(connection_record=None): + if dialect._has_events: + for fn in dialect.dispatch.do_connect: + connection = fn( + dialect, connection_record, cargs, cparams) + if connection is not None: + return connection + return dialect.connect(*cargs, **cparams) - poolclass = (kwargs.pop('poolclass', None) or - getattr(dialect_cls, 'poolclass', poollib.QueuePool)) - pool_args = {} + creator = pop_kwarg('creator', connect) + + poolclass = pop_kwarg('poolclass', None) + if poolclass is None: + poolclass = dialect_cls.get_pool_class(u) + pool_args = { + 'dialect': dialect + } # consume pool arguments from kwargs, translating a few of # the arguments @@ -94,12 +119,17 @@ class DefaultEngineStrategy(EngineStrategy): 'echo': 'echo_pool', 'timeout': 'pool_timeout', 'recycle': 'pool_recycle', - 'use_threadlocal':'pool_threadlocal'} + 'events': 'pool_events', + 'use_threadlocal': 'pool_threadlocal', + 'reset_on_return': 'pool_reset_on_return'} for k in util.get_cls_kwargs(poolclass): tk = translate.get(k, k) if tk in kwargs: - pool_args[k] = kwargs.pop(tk) - pool_args.setdefault('use_threadlocal', self.pool_threadlocal) + pool_args[k] = pop_kwarg(tk) + + for plugin in plugins: + plugin.handle_pool_kwargs(poolclass, pool_args) + pool = poolclass(creator, **pool_args) else: if isinstance(pool, poollib._DBProxy): @@ -107,15 +137,17 @@ class DefaultEngineStrategy(EngineStrategy): else: pool = pool + pool._dialect = dialect + # create engine. 
engineclass = self.engine_cls engine_args = {} for k in util.get_cls_kwargs(engineclass): if k in kwargs: - engine_args[k] = kwargs.pop(k) + engine_args[k] = pop_kwarg(k) _initialize = kwargs.pop('_initialize', True) - + # all kwargs should be consumed if kwargs: raise TypeError( @@ -126,24 +158,35 @@ class DefaultEngineStrategy(EngineStrategy): dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__)) - + engine = engineclass(pool, dialect, u, **engine_args) if _initialize: do_on_connect = dialect.on_connect() if do_on_connect: - def on_connect(conn, rec): - conn = getattr(conn, '_sqla_unwrap', conn) + def on_connect(dbapi_connection, connection_record): + conn = getattr( + dbapi_connection, '_sqla_unwrap', dbapi_connection) if conn is None: return do_on_connect(conn) - - pool.add_listener({'first_connect': on_connect, 'connect':on_connect}) - - def first_connect(conn, rec): - c = base.Connection(engine, connection=conn) + + event.listen(pool, 'first_connect', on_connect) + event.listen(pool, 'connect', on_connect) + + def first_connect(dbapi_connection, connection_record): + c = base.Connection(engine, connection=dbapi_connection, + _has_events=False) + c._execution_options = util.immutabledict() dialect.initialize(c) - pool.add_listener({'first_connect':first_connect}) + event.listen(pool, 'first_connect', first_connect, once=True) + + dialect_cls.engine_created(engine) + if entrypoint is not dialect_cls: + entrypoint.engine_created(engine) + + for plugin in plugins: + plugin.engine_created(engine) return engine @@ -153,15 +196,14 @@ class PlainEngineStrategy(DefaultEngineStrategy): name = 'plain' engine_cls = base.Engine - + PlainEngineStrategy() class ThreadLocalEngineStrategy(DefaultEngineStrategy): - """Strategy for configuring an Engine with thredlocal behavior.""" - + """Strategy for configuring an Engine with threadlocal behavior.""" + name = 'threadlocal' - pool_threadlocal = True engine_cls = threadlocal.TLEngine ThreadLocalEngineStrategy() @@ -172,11 +214,11 @@ class MockEngineStrategy(EngineStrategy): Produces a single mock Connectable object which dispatches statement execution to a passed-in function. 
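One documented use of this strategy is dumping DDL as strings rather than
executing it; a minimal sketch (the target URL and names are illustrative)::

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer

    metadata = MetaData()
    Table('t', metadata, Column('id', Integer, primary_key=True))

    def dump(sql, *multiparams, **params):
        print(sql.compile(dialect=engine.dialect))

    engine = create_engine('postgresql://', strategy='mock', executor=dump)
    metadata.create_all(engine, checkfirst=False)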
- + """ name = 'mock' - + def create(self, name_or_url, executor, **kwargs): # create url.URL object u = url.make_url(name_or_url) @@ -203,9 +245,14 @@ class MockEngineStrategy(EngineStrategy): dialect = property(attrgetter('_dialect')) name = property(lambda s: s._dialect.name) + schema_for_object = schema._schema_getter(None) + def contextual_connect(self, **kwargs): return self + def execution_options(self, **kw): + return self + def compiler(self, statement, parameters, **kwargs): return self._dialect.compiler( statement, parameters, engine=self, **kwargs) @@ -213,13 +260,22 @@ class MockEngineStrategy(EngineStrategy): def create(self, entity, **kwargs): kwargs['checkfirst'] = False from sqlalchemy.engine import ddl - - ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity) + + ddl.SchemaGenerator( + self.dialect, self, **kwargs).traverse_single(entity) def drop(self, entity, **kwargs): kwargs['checkfirst'] = False from sqlalchemy.engine import ddl - ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity) + ddl.SchemaDropper( + self.dialect, self, **kwargs).traverse_single(entity) + + def _run_visitor(self, visitorcallable, element, + connection=None, + **kwargs): + kwargs['checkfirst'] = False + visitorcallable(self.dialect, self, + **kwargs).traverse_single(element) def execute(self, object, *multiparams, **params): raise NotImplementedError() diff --git a/sqlalchemy/engine/threadlocal.py b/sqlalchemy/engine/threadlocal.py index 001caee..ee31764 100644 --- a/sqlalchemy/engine/threadlocal.py +++ b/sqlalchemy/engine/threadlocal.py @@ -1,23 +1,33 @@ +# engine/threadlocal.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + """Provides a thread-local transactional wrapper around the root Engine class. -The ``threadlocal`` module is invoked when using the ``strategy="threadlocal"`` flag -with :func:`~sqlalchemy.engine.create_engine`. This module is semi-private and is -invoked automatically when the threadlocal engine strategy is used. +The ``threadlocal`` module is invoked when using the +``strategy="threadlocal"`` flag with :func:`~sqlalchemy.engine.create_engine`. +This module is semi-private and is invoked automatically when the threadlocal +engine strategy is used. """ -from sqlalchemy import util -from sqlalchemy.engine import base +from .. import util +from . import base import weakref + class TLConnection(base.Connection): + def __init__(self, *arg, **kw): super(TLConnection, self).__init__(*arg, **kw) self.__opencount = 0 - + def _increment_connect(self): self.__opencount += 1 return self - + def close(self): if self.__opencount == 1: base.Connection.close(self) @@ -27,70 +37,95 @@ class TLConnection(base.Connection): self.__opencount = 0 base.Connection.close(self) - -class TLEngine(base.Engine): - """An Engine that includes support for thread-local managed transactions.""" +class TLEngine(base.Engine): + """An Engine that includes support for thread-local managed + transactions. 
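Typical usage of the threadlocal strategy relies on the ``begin()``,
``commit()``/``rollback()`` and ``__enter__``/``__exit__`` methods defined
below; a sketch (table and statements are illustrative)::

    from sqlalchemy import create_engine

    engine = create_engine('sqlite://', strategy='threadlocal')
    engine.execute("create table t (x integer)")

    with engine.begin():                # begin() returns the engine itself
        engine.execute("insert into t (x) values (1)")
    # __exit__ commits on success, rolls back on exception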
+ + """ + _tl_connection_cls = TLConnection def __init__(self, *args, **kwargs): super(TLEngine, self).__init__(*args, **kwargs) self._connections = util.threading.local() - proxy = kwargs.get('proxy') - if proxy: - self.TLConnection = base._proxy_connection_cls(TLConnection, proxy) - else: - self.TLConnection = TLConnection def contextual_connect(self, **kw): if not hasattr(self._connections, 'conn'): connection = None else: connection = self._connections.conn() - + if connection is None or connection.closed: # guards against pool-level reapers, if desired. # or not connection.connection.is_valid: - connection = self.TLConnection(self, self.pool.connect(), **kw) - self._connections.conn = conn = weakref.ref(connection) - + connection = self._tl_connection_cls( + self, + self._wrap_pool_connect( + self.pool.connect, connection), + **kw) + self._connections.conn = weakref.ref(connection) + return connection._increment_connect() - + def begin_twophase(self, xid=None): if not hasattr(self._connections, 'trans'): self._connections.trans = [] - self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid)) + self._connections.trans.append( + self.contextual_connect().begin_twophase(xid=xid)) + return self def begin_nested(self): if not hasattr(self._connections, 'trans'): self._connections.trans = [] - self._connections.trans.append(self.contextual_connect().begin_nested()) - + self._connections.trans.append( + self.contextual_connect().begin_nested()) + return self + def begin(self): if not hasattr(self._connections, 'trans'): self._connections.trans = [] self._connections.trans.append(self.contextual_connect().begin()) - + return self + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + if type is None: + self.commit() + else: + self.rollback() + def prepare(self): + if not hasattr(self._connections, 'trans') or \ + not self._connections.trans: + return self._connections.trans[-1].prepare() - + def commit(self): + if not hasattr(self._connections, 'trans') or \ + not self._connections.trans: + return trans = self._connections.trans.pop(-1) trans.commit() - + def rollback(self): + if not hasattr(self._connections, 'trans') or \ + not self._connections.trans: + return trans = self._connections.trans.pop(-1) trans.rollback() - + def dispose(self): self._connections = util.threading.local() super(TLEngine, self).dispose() - + @property def closed(self): return not hasattr(self._connections, 'conn') or \ - self._connections.conn() is None or \ - self._connections.conn().closed - + self._connections.conn() is None or \ + self._connections.conn().closed + def close(self): if not self.closed: self.contextual_connect().close() @@ -98,6 +133,6 @@ class TLEngine(base.Engine): connection._force_close() del self._connections.conn self._connections.trans = [] - + def __repr__(self): - return 'TLEngine(%s)' % str(self.url) + return 'TLEngine(%r)' % self.url diff --git a/sqlalchemy/engine/url.py b/sqlalchemy/engine/url.py index 5d658d7..1c16584 100644 --- a/sqlalchemy/engine/url.py +++ b/sqlalchemy/engine/url.py @@ -1,13 +1,23 @@ +# engine/url.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + """Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates information about a database connection specification. 
-The URL object is created automatically when :func:`~sqlalchemy.engine.create_engine` is called -with a string argument; alternatively, the URL is a public-facing construct which can +The URL object is created automatically when +:func:`~sqlalchemy.engine.create_engine` is called with a string +argument; alternatively, the URL is a public-facing construct which can be used directly and is also accepted directly by ``create_engine()``. """ -import re, cgi, sys, urllib -from sqlalchemy import exc +import re +from .. import exc, util +from . import Dialect +from ..dialects import registry, plugins class URL(object): @@ -15,8 +25,8 @@ class URL(object): Represent the components of a URL used to connect to a database. This object is suitable to be passed directly to a - ``create_engine()`` call. The fields of the URL are parsed from a - string by the ``module-level make_url()`` function. the string + :func:`~sqlalchemy.create_engine` call. The fields of the URL are parsed + from a string by the :func:`.make_url` function. the string format of the URL is an RFC-1738-style string. All initialization parameters are available as public attributes. @@ -53,25 +63,35 @@ class URL(object): self.database = database self.query = query or {} - def __str__(self): + def __to_string__(self, hide_password=True): s = self.drivername + "://" if self.username is not None: - s += self.username + s += _rfc_1738_quote(self.username) if self.password is not None: - s += ':' + urllib.quote_plus(self.password) + s += ':' + ('***' if hide_password + else _rfc_1738_quote(self.password)) s += "@" if self.host is not None: - s += self.host + if ':' in self.host: + s += "[%s]" % self.host + else: + s += self.host if self.port is not None: s += ':' + str(self.port) if self.database is not None: s += '/' + self.database if self.query: - keys = self.query.keys() + keys = list(self.query) keys.sort() s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys) return s + def __str__(self): + return self.__to_string__(hide_password=False) + + def __repr__(self): + return self.__to_string__() + def __hash__(self): return hash(str(self)) @@ -85,49 +105,58 @@ class URL(object): self.database == other.database and \ self.query == other.query + def get_backend_name(self): + if '+' not in self.drivername: + return self.drivername + else: + return self.drivername.split('+')[0] + + def get_driver_name(self): + if '+' not in self.drivername: + return self.get_dialect().driver + else: + return self.drivername.split('+')[1] + + def _instantiate_plugins(self, kwargs): + plugin_names = util.to_list(self.query.get('plugin', ())) + + return [ + plugins.load(plugin_name)(self, kwargs) + for plugin_name in plugin_names + ] + + def _get_entrypoint(self): + """Return the "entry point" dialect class. + + This is normally the dialect itself except in the case when the + returned class implements the get_dialect_cls() method. + + """ + if '+' not in self.drivername: + name = self.drivername + else: + name = self.drivername.replace('+', '.') + cls = registry.load(name) + # check for legacy dialects that + # would return a module with 'dialect' as the + # actual class + if hasattr(cls, 'dialect') and \ + isinstance(cls.dialect, type) and \ + issubclass(cls.dialect, Dialect): + return cls.dialect + else: + return cls + def get_dialect(self): """Return the SQLAlchemy database dialect class corresponding to this URL's driver name. 
""" + entrypoint = self._get_entrypoint() + dialect_cls = entrypoint.get_dialect_cls(self) + return dialect_cls - try: - if '+' in self.drivername: - dialect, driver = self.drivername.split('+') - else: - dialect, driver = self.drivername, 'base' - - module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects - module = getattr(module, dialect) - module = getattr(module, driver) - - return module.dialect - except ImportError: - module = self._load_entry_point() - if module is not None: - return module - else: - raise - - def _load_entry_point(self): - """attempt to load this url's dialect from entry points, or return None - if pkg_resources is not installed or there is no matching entry point. - - Raise ImportError if the actual load fails. - - """ - try: - import pkg_resources - except ImportError: - return None - - for res in pkg_resources.iter_entry_points('sqlalchemy.dialects'): - if res.name == self.drivername: - return res.load() - else: - return None - def translate_connect_args(self, names=[], **kw): - """Translate url attributes into a dictionary of connection arguments. + r"""Translate url attributes into a dictionary of connection arguments. Returns attributes of this url (`host`, `database`, `username`, `password`, `port`) as a plain dictionary. The attribute names are @@ -136,8 +165,8 @@ class URL(object): :param \**kw: Optional, alternate key names for url attributes. - :param names: Deprecated. Same purpose as the keyword-based alternate names, - but correlates the name to the original positionally. + :param names: Deprecated. Same purpose as the keyword-based alternate + names, but correlates the name to the original positionally. """ translated = {} @@ -153,6 +182,7 @@ class URL(object): translated[name] = getattr(self, sname) return translated + def make_url(name_or_url): """Given a string or unicode instance, produce a new URL instance. @@ -160,25 +190,28 @@ def make_url(name_or_url): existing URL object is passed, just returns the object. """ - if isinstance(name_or_url, basestring): + if isinstance(name_or_url, util.string_types): return _parse_rfc1738_args(name_or_url) else: return name_or_url + def _parse_rfc1738_args(name): pattern = re.compile(r''' (?P[\w\+]+):// (?: (?P[^:/]*) - (?::(?P[^/]*))? + (?::(?P.*))? @)? (?: - (?P[^/:]*) + (?: + \[(?P[^/]+)\] | + (?P[^/:]+) + )? (?::(?P[^/]*))? )? (?:/(?P.*))? 
- ''' - , re.X) + ''', re.X) m = pattern.match(name) if m is not None: @@ -186,29 +219,43 @@ def _parse_rfc1738_args(name): if components['database'] is not None: tokens = components['database'].split('?', 2) components['database'] = tokens[0] - query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None - # Py2K - if query is not None: + query = ( + len(tokens) > 1 and dict(util.parse_qsl(tokens[1]))) or None + if util.py2k and query is not None: query = dict((k.encode('ascii'), query[k]) for k in query) - # end Py2K else: query = None components['query'] = query - if components['password'] is not None: - components['password'] = urllib.unquote_plus(components['password']) + if components['username'] is not None: + components['username'] = _rfc_1738_unquote(components['username']) + if components['password'] is not None: + components['password'] = _rfc_1738_unquote(components['password']) + + ipv4host = components.pop('ipv4host') + ipv6host = components.pop('ipv6host') + components['host'] = ipv4host or ipv6host name = components.pop('name') return URL(name, **components) else: raise exc.ArgumentError( "Could not parse rfc1738 URL from string '%s'" % name) + +def _rfc_1738_quote(text): + return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text) + + +def _rfc_1738_unquote(text): + return util.unquote(text) + + def _parse_keyvalue_args(name): - m = re.match( r'(\w+)://(.*)', name) + m = re.match(r'(\w+)://(.*)', name) if m is not None: (name, args) = m.group(1, 2) - opts = dict( cgi.parse_qsl( args ) ) + opts = dict(util.parse_qsl(args)) return URL(name, *opts) else: return None diff --git a/sqlalchemy/exc.py b/sqlalchemy/exc.py index 31826f4..b2e07ae 100644 --- a/sqlalchemy/exc.py +++ b/sqlalchemy/exc.py @@ -1,13 +1,15 @@ -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# sqlalchemy/exc.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """Exceptions used with SQLAlchemy. -The base exception class is SQLAlchemyError. Exceptions which are raised as a -result of DBAPI exceptions are all subclasses of -:class:`~sqlalchemy.exc.DBAPIError`. +The base exception class is :exc:`.SQLAlchemyError`. Exceptions which are +raised as a result of DBAPI exceptions are all subclasses of +:exc:`.DBAPIError`. """ @@ -24,31 +26,100 @@ class ArgumentError(SQLAlchemyError): """ +class ObjectNotExecutableError(ArgumentError): + """Raised when an object is passed to .execute() that can't be + executed as SQL. + + .. versionadded:: 1.1 + + """ + + def __init__(self, target): + super(ObjectNotExecutableError, self).__init__( + "Not an executable object: %r" % target + ) + + +class NoSuchModuleError(ArgumentError): + """Raised when a dynamically-loaded module (usually a database dialect) + of a particular name cannot be located.""" + + +class NoForeignKeysError(ArgumentError): + """Raised when no foreign keys can be located between two selectables + during a join.""" + + +class AmbiguousForeignKeysError(ArgumentError): + """Raised when more than one foreign key matching can be located + between two selectables during a join.""" + + class CircularDependencyError(SQLAlchemyError): - """Raised by topological sorts when a circular dependency is detected""" + """Raised by topological sorts when a circular dependency is detected. 
+ + There are two scenarios where this error occurs: + + * In a Session flush operation, if two objects are mutually dependent + on each other, they can not be inserted or deleted via INSERT or + DELETE statements alone; an UPDATE will be needed to post-associate + or pre-deassociate one of the foreign key constrained values. + The ``post_update`` flag described at :ref:`post_update` can resolve + this cycle. + * In a :attr:`.MetaData.sorted_tables` operation, two :class:`.ForeignKey` + or :class:`.ForeignKeyConstraint` objects mutually refer to each + other. Apply the ``use_alter=True`` flag to one or both, + see :ref:`use_alter`. + + """ + def __init__(self, message, cycles, edges, msg=None): + if msg is None: + message += " (%s)" % ", ".join(repr(s) for s in cycles) + else: + message = msg + SQLAlchemyError.__init__(self, message) + self.cycles = cycles + self.edges = edges + + def __reduce__(self): + return self.__class__, (None, self.cycles, + self.edges, self.args[0]) class CompileError(SQLAlchemyError): """Raised when an error occurs during SQL compilation""" + +class UnsupportedCompilationError(CompileError): + """Raised when an operation is not supported by the given compiler. + + + .. versionadded:: 0.8.3 + + """ + + def __init__(self, compiler, element_type): + super(UnsupportedCompilationError, self).__init__( + "Compiler %r can't render element of type %s" % + (compiler, element_type)) + + class IdentifierError(SQLAlchemyError): """Raised when a schema name is beyond the max character limit""" -# Moved to orm.exc; compatability definition installed by orm import until 0.6 -ConcurrentModificationError = None class DisconnectionError(SQLAlchemyError): """A disconnect is detected on a raw DB-API connection. This error is raised and consumed internally by a connection pool. It can - be raised by a ``PoolListener`` so that the host pool forces a disconnect. + be raised by the :meth:`.PoolEvents.checkout` event so that the host pool + forces a retry; the exception will be caught three times in a row before + the pool gives up and raises :class:`~sqlalchemy.exc.InvalidRequestError` + regarding the connection attempt. """ -# Moved to orm.exc; compatability definition installed by orm import until 0.6 -FlushError = None - class TimeoutError(SQLAlchemyError): """Raised when a connection pool times out on getting a connection.""" @@ -60,17 +131,52 @@ class InvalidRequestError(SQLAlchemyError): """ + +class NoInspectionAvailable(InvalidRequestError): + """A subject passed to :func:`sqlalchemy.inspection.inspect` produced + no context for inspection.""" + + +class ResourceClosedError(InvalidRequestError): + """An operation was requested from a connection, cursor, or other + object that's in a closed state.""" + + class NoSuchColumnError(KeyError, InvalidRequestError): """A nonexistent column is requested from a ``RowProxy``.""" + class NoReferenceError(InvalidRequestError): """Raised by ``ForeignKey`` to indicate a reference cannot be resolved.""" - + + class NoReferencedTableError(NoReferenceError): - """Raised by ``ForeignKey`` when the referred ``Table`` cannot be located.""" + """Raised by ``ForeignKey`` when the referred ``Table`` cannot be + located. 
+ + """ + def __init__(self, message, tname): + NoReferenceError.__init__(self, message) + self.table_name = tname + + def __reduce__(self): + return self.__class__, (self.args[0], self.table_name) + class NoReferencedColumnError(NoReferenceError): - """Raised by ``ForeignKey`` when the referred ``Column`` cannot be located.""" + """Raised by ``ForeignKey`` when the referred ``Column`` cannot be + located. + + """ + def __init__(self, message, tname, cname): + NoReferenceError.__init__(self, message) + self.table_name = tname + self.column_name = cname + + def __reduce__(self): + return self.__class__, (self.args[0], self.table_name, + self.column_name) + class NoSuchTableError(InvalidRequestError): """Table does not exist or is not visible to a connection.""" @@ -80,70 +186,161 @@ class UnboundExecutionError(InvalidRequestError): """SQL was attempted without a database connection to execute it on.""" -# Moved to orm.exc; compatability definition installed by orm import until 0.6 +class DontWrapMixin(object): + """A mixin class which, when applied to a user-defined Exception class, + will not be wrapped inside of :exc:`.StatementError` if the error is + emitted within the process of executing a statement. + + E.g.:: + + from sqlalchemy.exc import DontWrapMixin + + class MyCustomException(Exception, DontWrapMixin): + pass + + class MySpecialType(TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + if value == 'invalid': + raise MyCustomException("invalid!") + + """ + +# Moved to orm.exc; compatibility definition installed by orm import until 0.6 UnmappedColumnError = None -class DBAPIError(SQLAlchemyError): + +class StatementError(SQLAlchemyError): + """An error occurred during execution of a SQL statement. + + :class:`StatementError` wraps the exception raised + during execution, and features :attr:`.statement` + and :attr:`.params` attributes which supply context regarding + the specifics of the statement which had an issue. + + The wrapped exception object is available in + the :attr:`.orig` attribute. + + """ + + statement = None + """The string SQL statement being invoked when this exception occurred.""" + + params = None + """The parameter list being used when this exception occurred.""" + + orig = None + """The DBAPI exception object.""" + + def __init__(self, message, statement, params, orig): + SQLAlchemyError.__init__(self, message) + self.statement = statement + self.params = params + self.orig = orig + self.detail = [] + + def add_detail(self, msg): + self.detail.append(msg) + + def __reduce__(self): + return self.__class__, (self.args[0], self.statement, + self.params, self.orig) + + def __str__(self): + from sqlalchemy.sql import util + + details = [SQLAlchemyError.__str__(self)] + if self.statement: + details.append("[SQL: %r]" % self.statement) + if self.params: + params_repr = util._repr_params(self.params, 10) + details.append("[parameters: %r]" % params_repr) + return ' '.join([ + "(%s)" % det for det in self.detail + ] + details) + + def __unicode__(self): + return self.__str__() + + +class DBAPIError(StatementError): """Raised when the execution of a database operation fails. - ``DBAPIError`` wraps exceptions raised by the DB-API underlying the + Wraps exceptions raised by the DB-API underlying the database operation. Driver-specific implementations of the standard DB-API exception types are wrapped by matching sub-types of SQLAlchemy's - ``DBAPIError`` when possible. 
DB-API's ``Error`` type maps to - ``DBAPIError`` in SQLAlchemy, otherwise the names are identical. Note + :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to + :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note that there is no guarantee that different DB-API implementations will raise the same exception type for any given error condition. - If the error-raising operation occured in the execution of a SQL - statement, that statement and its parameters will be available on - the exception object in the ``statement`` and ``params`` attributes. + :class:`DBAPIError` features :attr:`~.StatementError.statement` + and :attr:`~.StatementError.params` attributes which supply context + regarding the specifics of the statement which had an issue, for the + typical case when the error was raised within the context of + emitting a SQL statement. - The wrapped exception object is available in the ``orig`` attribute. - Its type and properties are DB-API implementation specific. + The wrapped exception object is available in the + :attr:`~.StatementError.orig` attribute. Its type and properties are + DB-API implementation specific. """ @classmethod - def instance(cls, statement, params, orig, connection_invalidated=False): + def instance(cls, statement, params, + orig, dbapi_base_err, + connection_invalidated=False, + dialect=None): # Don't ever wrap these, just return them directly as if # DBAPIError didn't exist. - if isinstance(orig, (KeyboardInterrupt, SystemExit)): + if (isinstance(orig, BaseException) and + not isinstance(orig, Exception)) or \ + isinstance(orig, DontWrapMixin): return orig if orig is not None: - name, glob = orig.__class__.__name__, globals() - if name in glob and issubclass(glob[name], DBAPIError): - cls = glob[name] + # not a DBAPI error, statement is present. + # raise a StatementError + if not isinstance(orig, dbapi_base_err) and statement: + return StatementError( + "(%s.%s) %s" % + (orig.__class__.__module__, orig.__class__.__name__, + orig), + statement, params, orig + ) + + glob = globals() + for super_ in orig.__class__.__mro__: + name = super_.__name__ + if dialect: + name = dialect.dbapi_exception_translation_map.get( + name, name) + if name in glob and issubclass(glob[name], DBAPIError): + cls = glob[name] + break return cls(statement, params, orig, connection_invalidated) + def __reduce__(self): + return self.__class__, (self.statement, self.params, + self.orig, self.connection_invalidated) + def __init__(self, statement, params, orig, connection_invalidated=False): try: text = str(orig) - except (KeyboardInterrupt, SystemExit): - raise - except Exception, e: + except Exception as e: text = 'Error in str() of DB-API-generated exception: ' + str(e) - SQLAlchemyError.__init__( - self, '(%s) %s' % (orig.__class__.__name__, text)) - self.statement = statement - self.params = params - self.orig = orig + StatementError.__init__( + self, + '(%s.%s) %s' % ( + orig.__class__.__module__, orig.__class__.__name__, text, ), + statement, + params, + orig + ) self.connection_invalidated = connection_invalidated - def __str__(self): - if isinstance(self.params, (list, tuple)) and len(self.params) > 10 and isinstance(self.params[0], (list, dict, tuple)): - return ' '.join((SQLAlchemyError.__str__(self), - repr(self.statement), - repr(self.params[:2]), - '... 
and a total of %i bound parameter sets' % len(self.params))) - return ' '.join((SQLAlchemyError.__str__(self), - repr(self.statement), repr(self.params))) - - -# As of 0.4, SQLError is now DBAPIError. -# SQLError alias will be removed in 0.6. -SQLError = DBAPIError class InterfaceError(DBAPIError): """Wraps a DB-API InterfaceError.""" diff --git a/sqlalchemy/ext/__init__.py b/sqlalchemy/ext/__init__.py index 8b13789..bb9ae58 100644 --- a/sqlalchemy/ext/__init__.py +++ b/sqlalchemy/ext/__init__.py @@ -1 +1,11 @@ +# ext/__init__.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .. import util as _sa_util + +_sa_util.dependencies.resolve_all("sqlalchemy.ext") diff --git a/sqlalchemy/ext/associationproxy.py b/sqlalchemy/ext/associationproxy.py index c7437d7..6f570a1 100644 --- a/sqlalchemy/ext/associationproxy.py +++ b/sqlalchemy/ext/associationproxy.py @@ -1,3 +1,10 @@ +# ext/associationproxy.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + """Contain the ``AssociationProxy`` class. The ``AssociationProxy`` is a Python property object which provides @@ -9,43 +16,37 @@ See the example ``examples/association/proxied_association.py``. import itertools import operator import weakref -from sqlalchemy import exceptions -from sqlalchemy import orm -from sqlalchemy import util -from sqlalchemy.orm import collections -from sqlalchemy.sql import not_ +from .. import exc, orm, util +from ..orm import collections, interfaces +from ..sql import not_, or_ def association_proxy(target_collection, attr, **kw): - """Return a Python property implementing a view of *attr* over a collection. + r"""Return a Python property implementing a view of a target + attribute which references an attribute on members of the + target. - Implements a read/write view over an instance's *target_collection*, - extracting *attr* from each member of the collection. The property acts - somewhat like this list comprehension:: + The returned value is an instance of :class:`.AssociationProxy`. - [getattr(member, *attr*) - for member in getattr(instance, *target_collection*)] + Implements a Python property representing a relationship as a collection + of simpler values, or a scalar value. The proxied property will mimic + the collection type of the target (list, dict or set), or, in the case of + a one to one relationship, a simple scalar value. - Unlike the list comprehension, the collection returned by the property is - always in sync with *target_collection*, and mutations made to either - collection will be reflected in both. + :param target_collection: Name of the attribute we'll proxy to. + This attribute is typically mapped by + :func:`~sqlalchemy.orm.relationship` to link to a target collection, but + can also be a many-to-one or non-scalar relationship. - Implements a Python property representing a relationship as a collection of - simpler values. The proxied property will mimic the collection type of - the target (list, dict or set), or, in the case of a one to one relationship, - a simple scalar value. - - :param target_collection: Name of the relationship attribute we'll proxy to, - usually created with :func:`~sqlalchemy.orm.relationship`. 
-
-    :param attr: Attribute on the associated instances we'll proxy for.
+    :param attr: Attribute on the associated instance or instances we'll
+      proxy for.

       For example, given a target collection of [obj1, obj2], a list created
       by this proxy property would look like [getattr(obj1, *attr*),
       getattr(obj2, *attr*)]

-      If the relationship is one-to-one or otherwise uselist=False, then simply:
-      getattr(obj, *attr*)
+      If the relationship is one-to-one or otherwise uselist=False, then
+      simply: getattr(obj, *attr*)

     :param creator: optional.

@@ -69,59 +70,78 @@ def association_proxy(target_collection, attr, **kw):
       situation.

     :param \*\*kw: Passes along any other keyword arguments to
-      :class:`AssociationProxy`.
+      :class:`.AssociationProxy`.

     """
     return AssociationProxy(target_collection, attr, **kw)


-class AssociationProxy(object):
+ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY')
+"""Symbol indicating an :class:`InspectionAttr` that's
+    of type :class:`.AssociationProxy`.

+   Is assigned to the :attr:`.InspectionAttr.extension_type`
+   attribute.
+
+"""
+
+
+class AssociationProxy(interfaces.InspectionAttrInfo):
     """A descriptor that presents a read/write view of an object attribute."""

+    is_attribute = False
+    extension_type = ASSOCIATION_PROXY
+
     def __init__(self, target_collection, attr, creator=None,
-                 getset_factory=None, proxy_factory=None, proxy_bulk_set=None):
-        """Arguments are:
+                 getset_factory=None, proxy_factory=None,
+                 proxy_bulk_set=None, info=None):
+        """Construct a new :class:`.AssociationProxy`.

-        target_collection
-          Name of the collection we'll proxy to, usually created with
-          'relationship()' in a mapper setup.
+        The :func:`.association_proxy` function is provided as the usual
+        entrypoint here, though :class:`.AssociationProxy` can be instantiated
+        and/or subclassed directly.

-        attr
-          Attribute on the collected instances we'll proxy for.  For example,
-          given a target collection of [obj1, obj2], a list created by this
-          proxy property would look like [getattr(obj1, attr), getattr(obj2,
-          attr)]
+        :param target_collection: Name of the collection we'll proxy to,
+          usually created with :func:`.relationship`.

-        creator
-          Optional. When new items are added to this proxied collection, new
-          instances of the class collected by the target collection will be
-          created.  For list and set collections, the target class constructor
-          will be called with the 'value' for the new instance.  For dict
-          types, two arguments are passed: key and value.
+        :param attr: Attribute on the collected instances we'll proxy
+          for.  For example, given a target collection of [obj1, obj2], a
+          list created by this proxy property would look like
+          [getattr(obj1, attr), getattr(obj2, attr)]
+
+        :param creator: Optional. When new items are added to this proxied
+          collection, new instances of the class collected by the target
+          collection will be created.  For list and set collections, the
+          target class constructor will be called with the 'value' for the
+          new instance.  For dict types, two arguments are passed:
+          key and value.

           If you want to construct instances differently, supply a 'creator'
           function that takes arguments as above and returns instances.

-        getset_factory
-          Optional.  Proxied attribute access is automatically handled by
-          routines that get and set values based on the `attr` argument for
-          this proxy.
+        :param getset_factory: Optional.  Proxied attribute access is
+          automatically handled by routines that get and set values based on
+          the `attr` argument for this proxy.
If you would like to customize this behavior, you may supply a `getset_factory` callable that produces a tuple of `getter` and `setter` functions. The factory is called with two arguments, the abstract type of the underlying collection and this proxy instance. - proxy_factory - Optional. The type of collection to emulate is determined by - sniffing the target collection. If your collection type can't be - determined by duck typing or you'd like to use a different - collection implementation, you may supply a factory function to - produce those collections. Only applicable to non-scalar relationships. + :param proxy_factory: Optional. The type of collection to emulate is + determined by sniffing the target collection. If your collection + type can't be determined by duck typing or you'd like to use a + different collection implementation, you may supply a factory + function to produce those collections. Only applicable to + non-scalar relationships. - proxy_bulk_set - Optional, use with proxy_factory. See the _set() method for - details. + :param proxy_bulk_set: Optional, use with proxy_factory. See + the _set() method for details. + + :param info: optional, will be assigned to + :attr:`.AssociationProxy.info` if present. + + .. versionadded:: 1.0.9 """ self.target_collection = target_collection @@ -131,36 +151,107 @@ class AssociationProxy(object): self.proxy_factory = proxy_factory self.proxy_bulk_set = proxy_bulk_set - self.scalar = None self.owning_class = None self.key = '_%s_%s_%s' % ( type(self).__name__, target_collection, id(self)) self.collection_class = None + if info: + self.info = info + + @property + def remote_attr(self): + """The 'remote' :class:`.MapperProperty` referenced by this + :class:`.AssociationProxy`. + + .. versionadded:: 0.7.3 + + See also: + + :attr:`.AssociationProxy.attr` + + :attr:`.AssociationProxy.local_attr` + + """ + return getattr(self.target_class, self.value_attr) + + @property + def local_attr(self): + """The 'local' :class:`.MapperProperty` referenced by this + :class:`.AssociationProxy`. + + .. versionadded:: 0.7.3 + + See also: + + :attr:`.AssociationProxy.attr` + + :attr:`.AssociationProxy.remote_attr` + + """ + return getattr(self.owning_class, self.target_collection) + + @property + def attr(self): + """Return a tuple of ``(local_attr, remote_attr)``. + + This attribute is convenient when specifying a join + using :meth:`.Query.join` across two relationships:: + + sess.query(Parent).join(*Parent.proxied.attr) + + .. versionadded:: 0.7.3 + + See also: + + :attr:`.AssociationProxy.local_attr` + + :attr:`.AssociationProxy.remote_attr` + + """ + return (self.local_attr, self.remote_attr) def _get_property(self): return (orm.class_mapper(self.owning_class). get_property(self.target_collection)) - @property + @util.memoized_property def target_class(self): - """The class the proxy is attached to.""" + """The intermediary class handled by this :class:`.AssociationProxy`. + + Intercepted append/set/assignment events will result + in the generation of new instances of this class. 
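Putting the proxy together end to end, a self-contained sketch of the
behavior described above (all mapped names here are illustrative)::

    from sqlalchemy import Column, ForeignKey, Integer, String, Table
    from sqlalchemy.ext.associationproxy import association_proxy
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    user_keyword = Table(
        'user_keyword', Base.metadata,
        Column('user_id', ForeignKey('user.id'), primary_key=True),
        Column('keyword_id', ForeignKey('keyword.id'), primary_key=True))

    class Keyword(Base):
        __tablename__ = 'keyword'
        id = Column(Integer, primary_key=True)
        keyword = Column(String(50), nullable=False)

    class User(Base):
        __tablename__ = 'user'
        id = Column(Integer, primary_key=True)
        kw = relationship(Keyword, secondary=user_keyword)

        # appending a plain string routes through the creator, which
        # builds the intermediary Keyword object
        keywords = association_proxy('kw', 'keyword',
                                     creator=lambda kw: Keyword(keyword=kw))

    u = User()
    u.keywords.append('cheese inspector')   # really appends a Keyword
    assert u.kw[0].keyword == 'cheese inspector'

In queries, the comparator methods defined further on come into play, e.g.
``session.query(User).filter(User.keywords.contains('cheese inspector'))``.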
+ + """ return self._get_property().mapper.class_ - def _target_is_scalar(self): - return not self._get_property().uselist + @util.memoized_property + def scalar(self): + """Return ``True`` if this :class:`.AssociationProxy` proxies a scalar + relationship on the local side.""" + + scalar = not self._get_property().uselist + if scalar: + self._initialize_scalar_accessors() + return scalar + + @util.memoized_property + def _value_is_scalar(self): + return not self._get_property().\ + mapper.get_property(self.value_attr).uselist + + @util.memoized_property + def _target_is_object(self): + return getattr(self.target_class, self.value_attr).impl.uses_objects def __get__(self, obj, class_): if self.owning_class is None: self.owning_class = class_ and class_ or type(obj) if obj is None: return self - elif self.scalar is None: - self.scalar = self._target_is_scalar() - if self.scalar: - self._initialize_scalar_accessors() if self.scalar: - return self._scalar_get(getattr(obj, self.target_collection)) + target = getattr(obj, self.target_collection) + return self._scalar_get(target) else: try: # If the owning instance is reborn (orm session resurrect, @@ -173,14 +264,10 @@ class AssociationProxy(object): proxy = self._new(_lazy_collection(obj, self.target_collection)) setattr(obj, self.key, (id(obj), proxy)) return proxy - + def __set__(self, obj, values): if self.owning_class is None: self.owning_class = type(obj) - if self.scalar is None: - self.scalar = self._target_is_scalar() - if self.scalar: - self._initialize_scalar_accessors() if self.scalar: creator = self.creator and self.creator or self.target_class @@ -209,7 +296,8 @@ class AssociationProxy(object): def _default_getset(self, collection_class): attr = self.value_attr - getter = operator.attrgetter(attr) + _getter = operator.attrgetter(attr) + getter = lambda target: _getter(target) if target is not None else None if collection_class is dict: setter = lambda o, k, v: setattr(o, attr, v) else: @@ -221,21 +309,25 @@ class AssociationProxy(object): self.collection_class = util.duck_type_collection(lazy_collection()) if self.proxy_factory: - return self.proxy_factory(lazy_collection, creator, self.value_attr, self) + return self.proxy_factory( + lazy_collection, creator, self.value_attr, self) if self.getset_factory: getter, setter = self.getset_factory(self.collection_class, self) else: getter, setter = self._default_getset(self.collection_class) - + if self.collection_class is list: - return _AssociationList(lazy_collection, creator, getter, setter, self) + return _AssociationList( + lazy_collection, creator, getter, setter, self) elif self.collection_class is dict: - return _AssociationDict(lazy_collection, creator, getter, setter, self) + return _AssociationDict( + lazy_collection, creator, getter, setter, self) elif self.collection_class is set: - return _AssociationSet(lazy_collection, creator, getter, setter, self) + return _AssociationSet( + lazy_collection, creator, getter, setter, self) else: - raise exceptions.ArgumentError( + raise exc.ArgumentError( 'could not guess which interface to use for ' 'collection_class "%s" backing "%s"; specify a ' 'proxy_factory and proxy_bulk_set manually' % @@ -248,7 +340,7 @@ class AssociationProxy(object): getter, setter = self.getset_factory(self.collection_class, self) else: getter, setter = self._default_getset(self.collection_class) - + proxy.creator = creator proxy.getter = getter proxy.setter = setter @@ -263,28 +355,102 @@ class AssociationProxy(object): elif self.collection_class is set: 
proxy.update(values) else: - raise exceptions.ArgumentError( - 'no proxy_bulk_set supplied for custom ' - 'collection_class implementation') + raise exc.ArgumentError( + 'no proxy_bulk_set supplied for custom ' + 'collection_class implementation') @property def _comparator(self): return self._get_property().comparator def any(self, criterion=None, **kwargs): - return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) - + """Produce a proxied 'any' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.RelationshipProperty.Comparator.any` + and/or :meth:`.RelationshipProperty.Comparator.has` + operators of the underlying proxied attributes. + + """ + if self._target_is_object: + if self._value_is_scalar: + value_expr = getattr( + self.target_class, self.value_attr).has( + criterion, **kwargs) + else: + value_expr = getattr( + self.target_class, self.value_attr).any( + criterion, **kwargs) + else: + value_expr = criterion + + # check _value_is_scalar here, otherwise + # we're scalar->scalar - call .any() so that + # the "can't call any() on a scalar" msg is raised. + if self.scalar and not self._value_is_scalar: + return self._comparator.has( + value_expr + ) + else: + return self._comparator.any( + value_expr + ) + def has(self, criterion=None, **kwargs): - return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) + """Produce a proxied 'has' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.RelationshipProperty.Comparator.any` + and/or :meth:`.RelationshipProperty.Comparator.has` + operators of the underlying proxied attributes. + + """ + + if self._target_is_object: + return self._comparator.has( + getattr(self.target_class, self.value_attr). + has(criterion, **kwargs) + ) + else: + if criterion is not None or kwargs: + raise exc.ArgumentError( + "Non-empty has() not allowed for " + "column-targeted association proxy; use ==") + return self._comparator.has() def contains(self, obj): - return self._comparator.any(**{self.value_attr: obj}) + """Produce a proxied 'contains' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.RelationshipProperty.Comparator.any` + , :meth:`.RelationshipProperty.Comparator.has`, + and/or :meth:`.RelationshipProperty.Comparator.contains` + operators of the underlying proxied attributes. + """ + + if self.scalar and not self._value_is_scalar: + return self._comparator.has( + getattr(self.target_class, self.value_attr).contains(obj) + ) + else: + return self._comparator.any(**{self.value_attr: obj}) def __eq__(self, obj): - return self._comparator.has(**{self.value_attr: obj}) + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + if obj is None: + return or_( + self._comparator.has(**{self.value_attr: obj}), + self._comparator == None + ) + else: + return self._comparator.has(**{self.value_attr: obj}) def __ne__(self, obj): - return not_(self.__eq__(obj)) + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. 
+ return self._comparator.has( + getattr(self.target_class, self.value_attr) != obj) class _lazy_collection(object): @@ -295,22 +461,23 @@ class _lazy_collection(object): def __call__(self): obj = self.ref() if obj is None: - raise exceptions.InvalidRequestError( - "stale association proxy, parent object has gone out of " - "scope") + raise exc.InvalidRequestError( + "stale association proxy, parent object has gone out of " + "scope") return getattr(obj, self.target) def __getstate__(self): - return {'obj':self.ref(), 'target':self.target} - + return {'obj': self.ref(), 'target': self.target} + def __setstate__(self, state): self.ref = weakref.ref(state['obj']) self.target = state['target'] + class _AssociationCollection(object): def __init__(self, lazy_collection, creator, getter, setter, parent): - """Constructs an _AssociationCollection. - + """Constructs an _AssociationCollection. + This will always be a subclass of either _AssociationList, _AssociationSet, or _AssociationDict. @@ -344,17 +511,20 @@ class _AssociationCollection(object): def __len__(self): return len(self.col) - def __nonzero__(self): + def __bool__(self): return bool(self.col) + __nonzero__ = __bool__ + def __getstate__(self): - return {'parent':self.parent, 'lazy_collection':self.lazy_collection} + return {'parent': self.parent, 'lazy_collection': self.lazy_collection} def __setstate__(self, state): self.parent = state['parent'] self.lazy_collection = state['lazy_collection'] self.parent._inflate(self) - + + class _AssociationList(_AssociationCollection): """Generic, converting, list-to-list proxy.""" @@ -368,7 +538,10 @@ class _AssociationList(_AssociationCollection): return self.setter(object, value) def __getitem__(self, index): - return self._get(self.col[index]) + if not isinstance(index, slice): + return self._get(self.col[index]) + else: + return [self._get(member) for member in self.col[index]] def __setitem__(self, index, value): if not isinstance(index, slice): @@ -382,11 +555,12 @@ class _AssociationList(_AssociationCollection): stop = index.stop step = index.step or 1 - rng = range(index.start or 0, stop, step) + start = index.start or 0 + rng = list(range(index.start or 0, stop, step)) if step == 1: for i in rng: - del self[index.start] - i = index.start + del self[start] + i = start for item in value: self.insert(i, item) i += 1 @@ -429,7 +603,7 @@ class _AssociationList(_AssociationCollection): for member in self.col: yield self._get(member) - raise StopIteration + return def append(self, value): item = self._create(value) @@ -437,7 +611,7 @@ class _AssociationList(_AssociationCollection): def count(self, value): return sum([1 for _ in - itertools.ifilter(lambda v: v == value, iter(self))]) + util.itertools_filter(lambda v: v == value, iter(self))]) def extend(self, values): for v in values: @@ -536,14 +710,16 @@ class _AssociationList(_AssociationCollection): def __hash__(self): raise TypeError("%s objects are unhashable" % type(self).__name__) - for func_name, func in locals().items(): - if (util.callable(func) and func.func_name == func_name and - not func.__doc__ and hasattr(list, func_name)): + for func_name, func in list(locals().items()): + if (util.callable(func) and func.__name__ == func_name and + not func.__doc__ and hasattr(list, func_name)): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func _NotProvided = util.symbol('_NotProvided') + + class _AssociationDict(_AssociationCollection): """Generic, converting, dict-to-dict proxy.""" @@ -577,7 +753,7 @@ class 
_AssociationDict(_AssociationCollection): return key in self.col def __iter__(self): - return self.col.iterkeys() + return iter(self.col.keys()) def clear(self): self.col.clear() @@ -622,24 +798,27 @@ class _AssociationDict(_AssociationCollection): def keys(self): return self.col.keys() - def iterkeys(self): - return self.col.iterkeys() + if util.py2k: + def iteritems(self): + return ((key, self._get(self.col[key])) for key in self.col) - def values(self): - return [ self._get(member) for member in self.col.values() ] + def itervalues(self): + return (self._get(self.col[key]) for key in self.col) - def itervalues(self): - for key in self.col: - yield self._get(self.col[key]) - raise StopIteration + def iterkeys(self): + return self.col.iterkeys() - def items(self): - return [(k, self._get(self.col[k])) for k in self] + def values(self): + return [self._get(member) for member in self.col.values()] - def iteritems(self): - for key in self.col: - yield (key, self._get(self.col[key])) - raise StopIteration + def items(self): + return [(k, self._get(self.col[k])) for k in self] + else: + def items(self): + return ((key, self._get(self.col[key])) for key in self.col) + + def values(self): + return (self._get(self.col[key]) for key in self.col) def pop(self, key, default=_NotProvided): if default is _NotProvided: @@ -658,11 +837,20 @@ class _AssociationDict(_AssociationCollection): len(a)) elif len(a) == 1: seq_or_map = a[0] - for item in seq_or_map: - if isinstance(item, tuple): - self[item[0]] = item[1] - else: + # discern dict from sequence - took the advice from + # http://www.voidspace.org.uk/python/articles/duck_typing.shtml + # still not perfect :( + if hasattr(seq_or_map, 'keys'): + for item in seq_or_map: self[item] = seq_or_map[item] + else: + try: + for k, v in seq_or_map: + self[k] = v + except ValueError: + raise ValueError( + "dictionary update sequence " + "requires 2-element tuples") for key, value in kw: self[key] = value @@ -673,9 +861,9 @@ class _AssociationDict(_AssociationCollection): def __hash__(self): raise TypeError("%s objects are unhashable" % type(self).__name__) - for func_name, func in locals().items(): - if (util.callable(func) and func.func_name == func_name and - not func.__doc__ and hasattr(dict, func_name)): + for func_name, func in list(locals().items()): + if (util.callable(func) and func.__name__ == func_name and + not func.__doc__ and hasattr(dict, func_name)): func.__doc__ = getattr(dict, func_name).__doc__ del func_name, func @@ -695,12 +883,14 @@ class _AssociationSet(_AssociationCollection): def __len__(self): return len(self.col) - def __nonzero__(self): + def __bool__(self): if self.col: return True else: return False + __nonzero__ = __bool__ + def __contains__(self, value): for member in self.col: # testlib.pragma exempt:__eq__ @@ -717,7 +907,7 @@ class _AssociationSet(_AssociationCollection): """ for member in self.col: yield self._get(member) - raise StopIteration + return def add(self, value): if value not in self: @@ -871,8 +1061,8 @@ class _AssociationSet(_AssociationCollection): def __hash__(self): raise TypeError("%s objects are unhashable" % type(self).__name__) - for func_name, func in locals().items(): - if (util.callable(func) and func.func_name == func_name and - not func.__doc__ and hasattr(set, func_name)): + for func_name, func in list(locals().items()): + if (util.callable(func) and func.__name__ == func_name and + not func.__doc__ and hasattr(set, func_name)): func.__doc__ = getattr(set, func_name).__doc__ del func_name, func diff 
--git a/sqlalchemy/ext/compiler.py b/sqlalchemy/ext/compiler.py index 3226b0e..8b2bc95 100644 --- a/sqlalchemy/ext/compiler.py +++ b/sqlalchemy/ext/compiler.py @@ -1,31 +1,39 @@ -"""Provides an API for creation of custom ClauseElements and compilers. +# ext/compiler.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +r"""Provides an API for creation of custom ClauseElements and compilers. Synopsis ======== -Usage involves the creation of one or more :class:`~sqlalchemy.sql.expression.ClauseElement` -subclasses and one or more callables defining its compilation:: +Usage involves the creation of one or more +:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or +more callables defining its compilation:: from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.expression import ColumnClause - + class MyColumn(ColumnClause): pass - + @compiles(MyColumn) def compile_mycolumn(element, compiler, **kw): return "[%s]" % element.name - + Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, the base expression element for named column objects. The ``compiles`` decorator registers itself with the ``MyColumn`` class so that it is invoked when the object is compiled to a string:: from sqlalchemy import select - + s = select([MyColumn('x'), MyColumn('y')]) print str(s) - + Produces:: SELECT [x], [y] @@ -50,22 +58,25 @@ invoked for the dialect in use:: @compiles(AlterColumn, 'postgresql') def visit_alter_column(element, compiler, **kw): - return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, element.column.name) + return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, + element.column.name) -The second ``visit_alter_table`` will be invoked when any ``postgresql`` dialect is used. +The second ``visit_alter_table`` will be invoked when any ``postgresql`` +dialect is used. Compiling sub-elements of a custom expression construct ======================================================= -The ``compiler`` argument is the :class:`~sqlalchemy.engine.base.Compiled` -object in use. This object can be inspected for any information about the -in-progress compilation, including ``compiler.dialect``, -``compiler.statement`` etc. The :class:`~sqlalchemy.sql.compiler.SQLCompiler` -and :class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()`` +The ``compiler`` argument is the +:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object +can be inspected for any information about the in-progress compilation, +including ``compiler.dialect``, ``compiler.statement`` etc. The +:class:`~sqlalchemy.sql.compiler.SQLCompiler` and +:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()`` method which can be used for compilation of embedded attributes:: from sqlalchemy.sql.expression import Executable, ClauseElement - + class InsertFromSelect(Executable, ClauseElement): def __init__(self, table, select): self.table = table @@ -80,36 +91,110 @@ method which can be used for compilation of embedded attributes:: insert = InsertFromSelect(t1, select([t1]).where(t1.c.x>5)) print insert - + Produces:: - "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)" + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z + FROM mytable WHERE mytable.x > :x_1)" + +.. 
note:: + + The above ``InsertFromSelect`` construct is only an example, this actual + functionality is already available using the + :meth:`.Insert.from_select` method. + +.. note:: + + The above ``InsertFromSelect`` construct probably wants to have "autocommit" + enabled. See :ref:`enabling_compiled_autocommit` for this step. Cross Compiling between SQL and DDL compilers --------------------------------------------- -SQL and DDL constructs are each compiled using different base compilers - ``SQLCompiler`` -and ``DDLCompiler``. A common need is to access the compilation rules of SQL expressions -from within a DDL expression. The ``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as below where we generate a CHECK -constraint that embeds a SQL expression:: +SQL and DDL constructs are each compiled using different base compilers - +``SQLCompiler`` and ``DDLCompiler``. A common need is to access the +compilation rules of SQL expressions from within a DDL expression. The +``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as +below where we generate a CHECK constraint that embeds a SQL expression:: @compiles(MyConstraint) def compile_my_constraint(constraint, ddlcompiler, **kw): return "CONSTRAINT %s CHECK (%s)" % ( constraint.name, - ddlcompiler.sql_compiler.process(constraint.expression) + ddlcompiler.sql_compiler.process( + constraint.expression, literal_binds=True) ) +Above, we add an additional flag to the process step as called by +:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This +indicates that any SQL expression which refers to a :class:`.BindParameter` +object or other "literal" object such as those which refer to strings or +integers should be rendered **in-place**, rather than being referred to as +a bound parameter; when emitting DDL, bound parameters are typically not +supported. + + +.. _enabling_compiled_autocommit: + +Enabling Autocommit on a Construct +================================== + +Recall from the section :ref:`autocommit` that the :class:`.Engine`, when +asked to execute a construct in the absence of a user-defined transaction, +detects if the given construct represents DML or DDL, that is, a data +modification or data definition statement, which requires (or may require, +in the case of DDL) that the transaction generated by the DBAPI be committed +(recall that DBAPI always has a transaction going on regardless of what +SQLAlchemy does). Checking for this is actually accomplished by checking for +the "autocommit" execution option on the construct. When building a +construct like an INSERT derivation, a new DDL type, or perhaps a stored +procedure that alters data, the "autocommit" option needs to be set in order +for the statement to function with "connectionless" execution +(as described in :ref:`dbengine_implicit`). 
+
+Currently a quick way to do this is to subclass :class:`.Executable`, then
+add the "autocommit" flag to the ``_execution_options`` dictionary (note this
+is a "frozen" dictionary which supplies a generative ``union()`` method)::
+
+    from sqlalchemy.sql.expression import Executable, ClauseElement
+
+    class MyInsertThing(Executable, ClauseElement):
+        _execution_options = \
+            Executable._execution_options.union({'autocommit': True})
+
+More succinctly, if the construct is truly similar to an INSERT, UPDATE, or
+DELETE, :class:`.UpdateBase` can be used, which already is a subclass
+of :class:`.Executable`, :class:`.ClauseElement` and includes the
+``autocommit`` flag::
+
+    from sqlalchemy.sql.expression import UpdateBase
+
+    class MyInsertThing(UpdateBase):
+        def __init__(self, ...):
+            ...
+
+
+
+
+DDL elements that subclass :class:`.DDLElement` already have the
+"autocommit" flag turned on.
+
+
+
+
Changing the default compilation of existing constructs =======================================================

-The compiler extension applies just as well to the existing constructs. When overriding
-the compilation of a built in SQL construct, the @compiles decorator is invoked upon
-the appropriate class (be sure to use the class, i.e. ``Insert`` or ``Select``, instead of the creation function such as ``insert()`` or ``select()``).
+The compiler extension applies just as well to existing constructs. When
+overriding the compilation of a built-in SQL construct, the @compiles
+decorator is invoked upon the appropriate class (be sure to use the class,
+i.e. ``Insert`` or ``Select``, instead of the creation function such
+as ``insert()`` or ``select()``).

-Within the new compilation function, to get at the "original" compilation routine,
-use the appropriate visit_XXX method - this because compiler.process() will call upon the
-overriding routine and cause an endless loop. Such as, to add "prefix" to all insert statements::
+Within the new compilation function, to get at the "original" compilation
+routine, use the appropriate visit_XXX method - this is
+because compiler.process() will call upon the overriding routine and cause
+an endless loop. For example, to add "prefix" to all insert statements::

    from sqlalchemy.sql.expression import Insert @@ -117,38 +202,77 @@ overriding routine and cause an endless loop. Such as, to add "prefix" to all
    def prefix_inserts(insert, compiler, **kw):
        return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)

-The above compiler will prefix all INSERT statements with "some prefix" when compiled.
+The above compiler will prefix all INSERT statements with "some prefix" when
+compiled.
+
+.. _type_compilation_extension:
+
+Changing Compilation of Types
+=============================
+
+``compiler`` works for types, too, such as below where we implement the
+MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
+
+    @compiles(String, 'mssql')
+    @compiles(VARCHAR, 'mssql')
+    def compile_varchar(element, compiler, **kw):
+        if element.length == 'max':
+            return "VARCHAR('max')"
+        else:
+            return compiler.visit_VARCHAR(element, **kw)
+
+    foo = Table('foo', metadata,
+        Column('data', VARCHAR('max'))
+    )

Subclassing Guidelines ======================

-A big part of using the compiler extension is subclassing SQLAlchemy expression constructs. To make this easier, the expression and schema packages feature a set of "bases" intended for common tasks.
A synopsis is as follows:
+A big part of using the compiler extension is subclassing SQLAlchemy
+expression constructs. To make this easier, the expression and
+schema packages feature a set of "bases" intended for common tasks.
+A synopsis is as follows:

* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root expression class. Any SQL expression can be derived from this base, and is probably the best choice for longer constructs such as specialized INSERT statements.
- +
* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all "column-like" elements. Anything that you'd place in the "columns" clause of a SELECT statement (as well as order by and group by) can derive from this - the object will automatically have Python "comparison" behavior.
- +
  :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a ``type`` member which is the expression's return type. This can be established at the instance level in the constructor, or at the class level if it's generally constant::
- +
      class timestamp(ColumnElement):
          type = TIMESTAMP()
- -* :class:`~sqlalchemy.sql.expression.FunctionElement` - This is a hybrid of a
+ +* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a
  ``ColumnElement`` and a "from clause" like object, and represents a SQL function or stored procedure type of call. Since most databases support statements along the line of "SELECT FROM <some function>" ``FunctionElement`` adds in the ability to be used in the FROM clause of a
-  ``select()`` construct.
- +
+  ``select()`` construct::
+
+      from sqlalchemy.sql.expression import FunctionElement
+
+      class coalesce(FunctionElement):
+          name = 'coalesce'
+
+      @compiles(coalesce)
+      def compile(element, compiler, **kw):
+          return "coalesce(%s)" % compiler.process(element.clauses)
+
+      @compiles(coalesce, 'oracle')
+      def compile(element, compiler, **kw):
+          if len(element.clauses) > 2:
+              raise TypeError("coalesce only supports two arguments on Oracle")
+          return "nvl(%s)" % compiler.process(element.clauses)
+
* :class:`~sqlalchemy.schema.DDLElement` - The root of all DDL expressions, like CREATE TABLE, ALTER TABLE, etc. Compilation of ``DDLElement`` subclasses is issued by a ``DDLCompiler`` instead of a ``SQLCompiler``. @@ -156,39 +280,195 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression ``execute_at()`` method, allowing the construct to be invoked during CREATE TABLE and DROP TABLE sequences.

-* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which should be
-  used with any expression class that represents a "standalone" SQL statement that
-  can be passed directly to an ``execute()`` method. It is already implicit
-  within ``DDLElement`` and ``FunctionElement``.
- +
+* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which
+  should be used with any expression class that represents a "standalone"
+  SQL statement that can be passed directly to an ``execute()`` method. It
+  is already implicit within ``DDLElement`` and ``FunctionElement``.
+
+Further Examples
+================
+
+"UTC timestamp" function
+-------------------------
+
+A function that works like "CURRENT_TIMESTAMP" except it applies the
+appropriate conversions so that the time is in UTC time. Timestamps are best
+stored in relational databases as UTC, without time zones.
UTC so that your
+database doesn't think time has gone backwards in the hour when daylight
+savings ends, without timezones because timezones are like character
+encodings - they're best applied only at the endpoints of an application
+(i.e. convert to UTC upon user input, re-apply desired timezone upon display).
+
+For PostgreSQL and Microsoft SQL Server::
+
+    from sqlalchemy.sql import expression
+    from sqlalchemy.ext.compiler import compiles
+    from sqlalchemy.types import DateTime
+
+    class utcnow(expression.FunctionElement):
+        type = DateTime()
+
+    @compiles(utcnow, 'postgresql')
+    def pg_utcnow(element, compiler, **kw):
+        return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
+
+    @compiles(utcnow, 'mssql')
+    def ms_utcnow(element, compiler, **kw):
+        return "GETUTCDATE()"
+
+Example usage::
+
+    from sqlalchemy import (
+        Table, Column, Integer, String, DateTime, MetaData
+    )
+    metadata = MetaData()
+    event = Table("event", metadata,
+        Column("id", Integer, primary_key=True),
+        Column("description", String(50), nullable=False),
+        Column("timestamp", DateTime, server_default=utcnow())
+    )
+
+"GREATEST" function
+-------------------
+
+The "GREATEST" function is given any number of arguments and returns the one
+that is of the highest value - it's equivalent to Python's ``max``
+function. Below is a SQL standard version, versus a CASE based version which
+only accommodates two arguments::
+
+    from sqlalchemy.sql import expression
+    from sqlalchemy.ext.compiler import compiles
+    from sqlalchemy.types import Numeric
+
+    class greatest(expression.FunctionElement):
+        type = Numeric()
+        name = 'greatest'
+
+    @compiles(greatest)
+    def default_greatest(element, compiler, **kw):
+        return compiler.visit_function(element)
+
+    @compiles(greatest, 'sqlite')
+    @compiles(greatest, 'mssql')
+    @compiles(greatest, 'oracle')
+    def case_greatest(element, compiler, **kw):
+        arg1, arg2 = list(element.clauses)
+        return "CASE WHEN %s > %s THEN %s ELSE %s END" % (
+            compiler.process(arg1),
+            compiler.process(arg2),
+            compiler.process(arg1),
+            compiler.process(arg2),
+        )
+
+Example usage::
+
+    Session.query(Account).\
+        filter(
+            greatest(
+                Account.checking_balance,
+                Account.savings_balance) > 10000
+        )
+
+"false" expression
+------------------
+
+Render a "false" constant expression, which renders as "0" on platforms that
+don't have a "false" constant::
+
+    from sqlalchemy.sql import expression
+    from sqlalchemy.ext.compiler import compiles
+
+    class sql_false(expression.ColumnElement):
+        pass
+
+    @compiles(sql_false)
+    def default_false(element, compiler, **kw):
+        return "false"
+
+    @compiles(sql_false, 'mssql')
+    @compiles(sql_false, 'mysql')
+    @compiles(sql_false, 'oracle')
+    def int_false(element, compiler, **kw):
+        return "0"
+
+Example usage::
+
+    from sqlalchemy import select, union_all
+
+    exp = union_all(
+        select([users.c.name, sql_false().label("enrolled")]),
+        select([customers.c.name, customers.c.enrolled])
+    )
+
"""
+from .. import exc
+from ..sql import visitors
+
def compiles(class_, *specs):
+    """Register a function as a compiler for a
+    given :class:`.ClauseElement` type."""
+
    def decorate(fn):
-        existing = getattr(class_, '_compiler_dispatcher', None)
+        # get an existing @compiles handler
+        existing = class_.__dict__.get('_compiler_dispatcher', None)
+
+        # get the original handler. All ClauseElement classes have one
+        # of these, but some TypeEngine classes will not.
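+        # (note the asymmetry: ``class_.__dict__`` above looks only at
+        # class_ itself so that a superclass's dispatcher is not reused,
+        # while getattr() here intentionally finds an inherited
+        # visit handler as well)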
+ existing_dispatch = getattr(class_, '_compiler_dispatch', None) + if not existing: existing = _dispatcher() + if existing_dispatch: + def _wrap_existing_dispatch(element, compiler, **kw): + try: + return existing_dispatch(element, compiler, **kw) + except exc.UnsupportedCompilationError: + raise exc.CompileError( + "%s construct has no default " + "compilation handler." % type(element)) + existing.specs['default'] = _wrap_existing_dispatch + # TODO: why is the lambda needed ? - setattr(class_, '_compiler_dispatch', lambda *arg, **kw: existing(*arg, **kw)) + setattr(class_, '_compiler_dispatch', + lambda *arg, **kw: existing(*arg, **kw)) setattr(class_, '_compiler_dispatcher', existing) - + if specs: for s in specs: existing.specs[s] = fn + else: existing.specs['default'] = fn return fn return decorate - + + +def deregister(class_): + """Remove all custom compilers associated with a given + :class:`.ClauseElement` type.""" + + if hasattr(class_, '_compiler_dispatcher'): + # regenerate default _compiler_dispatch + visitors._generate_dispatch(class_) + # remove custom directive + del class_._compiler_dispatcher + + class _dispatcher(object): def __init__(self): self.specs = {} - + def __call__(self, element, compiler, **kw): # TODO: yes, this could also switch off of DBAPI in use. fn = self.specs.get(compiler.dialect.name, None) if not fn: - fn = self.specs['default'] + try: + fn = self.specs['default'] + except KeyError: + raise exc.CompileError( + "%s construct has no default " + "compilation handler." % type(element)) + return fn(element, compiler, **kw) - diff --git a/sqlalchemy/ext/horizontal_shard.py b/sqlalchemy/ext/horizontal_shard.py index 78e3f59..d20fbd4 100644 --- a/sqlalchemy/ext/horizontal_shard.py +++ b/sqlalchemy/ext/horizontal_shard.py @@ -1,5 +1,6 @@ -# horizontal_shard.py -# Copyright (C) the SQLAlchemy authors and contributors +# ext/horizontal_shard.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -9,105 +10,54 @@ Defines a rudimental 'horizontal sharding' system which allows a Session to distribute queries and persistence operations across multiple databases. -For a usage example, see the :ref:`examples_sharding` example included in -the source distrbution. +For a usage example, see the :ref:`examples_sharding` example included in +the source distribution. """ -import sqlalchemy.exceptions as sa_exc -from sqlalchemy import util -from sqlalchemy.orm.session import Session -from sqlalchemy.orm.query import Query +from .. import util +from ..orm.session import Session +from ..orm.query import Query __all__ = ['ShardedSession', 'ShardedQuery'] -class ShardedSession(Session): - def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs): - """Construct a ShardedSession. - - :param shard_chooser: A callable which, passed a Mapper, a mapped instance, and possibly a - SQL clause, returns a shard ID. This id may be based off of the - attributes present within the object, or on some round-robin - scheme. If the scheme is based on a selection, it should set - whatever state on the instance to mark it in the future as - participating in that shard. - - :param id_chooser: A callable, passed a query and a tuple of identity values, which - should return a list of shard ids where the ID might reside. The - databases will be queried in the order of this listing. 
- - :param query_chooser: For a given Query, returns the list of shard_ids where the query - should be issued. Results from all shards returned will be combined - together into a single listing. - - :param shards: A dictionary of string shard names to :class:`~sqlalchemy.engine.base.Engine` - objects. - - """ - super(ShardedSession, self).__init__(**kwargs) - self.shard_chooser = shard_chooser - self.id_chooser = id_chooser - self.query_chooser = query_chooser - self.__binds = {} - self._mapper_flush_opts = {'connection_callable':self.connection} - self._query_cls = ShardedQuery - if shards is not None: - for k in shards: - self.bind_shard(k, shards[k]) - - def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): - if shard_id is None: - shard_id = self.shard_chooser(mapper, instance) - - if self.transaction is not None: - return self.transaction.connection(mapper, shard_id=shard_id) - else: - return self.get_bind(mapper, - shard_id=shard_id, - instance=instance).contextual_connect(**kwargs) - - def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw): - if shard_id is None: - shard_id = self.shard_chooser(mapper, instance, clause=clause) - return self.__binds[shard_id] - - def bind_shard(self, shard_id, bind): - self.__binds[shard_id] = bind - class ShardedQuery(Query): def __init__(self, *args, **kwargs): super(ShardedQuery, self).__init__(*args, **kwargs) self.id_chooser = self.session.id_chooser self.query_chooser = self.session.query_chooser self._shard_id = None - + def set_shard(self, shard_id): """return a new query, limited to a single shard ID. - - all subsequent operations with the returned query will + + all subsequent operations with the returned query will be against the single shard regardless of other state. """ - + q = self._clone() q._shard_id = shard_id return q - + def _execute_and_instances(self, context): - if self._shard_id is not None: - result = self.session.connection( - mapper=self._mapper_zero(), - shard_id=self._shard_id).execute(context.statement, self._params) + def iter_for_shard(shard_id): + context.attributes['shard_id'] = shard_id + result = self._connection_from_session( + mapper=self._mapper_zero(), + shard_id=shard_id).execute( + context.statement, + self._params) return self.instances(result, context) + + if self._shard_id is not None: + return iter_for_shard(self._shard_id) else: partial = [] for shard_id in self.query_chooser(self): - result = self.session.connection( - mapper=self._mapper_zero(), - shard_id=shard_id).execute(context.statement, self._params) - partial = partial + list(self.instances(result, context)) - - # if some kind of in memory 'sorting' + partial.extend(iter_for_shard(shard_id)) + + # if some kind of in memory 'sorting' # were done, this is where it would happen return iter(partial) @@ -122,4 +72,60 @@ class ShardedQuery(Query): return o else: return None - + + +class ShardedSession(Session): + def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, + query_cls=ShardedQuery, **kwargs): + """Construct a ShardedSession. + + :param shard_chooser: A callable which, passed a Mapper, a mapped + instance, and possibly a SQL clause, returns a shard ID. This id + may be based off of the attributes present within the object, or on + some round-robin scheme. If the scheme is based on a selection, it + should set whatever state on the instance to mark it in the future as + participating in that shard. 
+ + :param id_chooser: A callable, passed a query and a tuple of identity + values, which should return a list of shard ids where the ID might + reside. The databases will be queried in the order of this listing. + + :param query_chooser: For a given Query, returns the list of shard_ids + where the query should be issued. Results from all shards returned + will be combined together into a single listing. + + :param shards: A dictionary of string shard names + to :class:`~sqlalchemy.engine.Engine` objects. + + """ + super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs) + self.shard_chooser = shard_chooser + self.id_chooser = id_chooser + self.query_chooser = query_chooser + self.__binds = {} + self.connection_callable = self.connection + if shards is not None: + for k in shards: + self.bind_shard(k, shards[k]) + + def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): + if shard_id is None: + shard_id = self.shard_chooser(mapper, instance) + + if self.transaction is not None: + return self.transaction.connection(mapper, shard_id=shard_id) + else: + return self.get_bind( + mapper, + shard_id=shard_id, + instance=instance + ).contextual_connect(**kwargs) + + def get_bind(self, mapper, shard_id=None, + instance=None, clause=None, **kw): + if shard_id is None: + shard_id = self.shard_chooser(mapper, instance, clause=clause) + return self.__binds[shard_id] + + def bind_shard(self, shard_id, bind): + self.__binds[shard_id] = bind diff --git a/sqlalchemy/ext/orderinglist.py b/sqlalchemy/ext/orderinglist.py index 0d2c3ae..6b22aa6 100644 --- a/sqlalchemy/ext/orderinglist.py +++ b/sqlalchemy/ext/orderinglist.py @@ -1,58 +1,78 @@ -"""A custom list that manages index/position information for its children. +# ext/orderinglist.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""A custom list that manages index/position information for contained +elements. :author: Jason Kirtland -``orderinglist`` is a helper for mutable ordered relationships. It will intercept -list operations performed on a relationship collection and automatically -synchronize changes in list position with an attribute on the related objects. -(See :ref:`advdatamapping_entitycollections` for more information on the general pattern.) +``orderinglist`` is a helper for mutable ordered relationships. It will +intercept list operations performed on a :func:`.relationship`-managed +collection and +automatically synchronize changes in list position onto a target scalar +attribute. -Example: Two tables that store slides in a presentation. Each slide -has a number of bullet points, displayed in order by the 'position' -column on the bullets table. These bullets can be inserted and re-ordered -by your end users, and you need to update the 'position' column of all -affected rows when changes are made. +Example: A ``slide`` table, where each row refers to zero or more entries +in a related ``bullet`` table. The bullets within a slide are +displayed in order based on the value of the ``position`` column in the +``bullet`` table. As entries are reordered in memory, the value of the +``position`` attribute should be updated to reflect the new sort order:: -.. 
sourcecode:: python+sql - slides_table = Table('Slides', metadata, - Column('id', Integer, primary_key=True), - Column('name', String)) + Base = declarative_base() - bullets_table = Table('Bullets', metadata, - Column('id', Integer, primary_key=True), - Column('slide_id', Integer, ForeignKey('Slides.id')), - Column('position', Integer), - Column('text', String)) + class Slide(Base): + __tablename__ = 'slide' - class Slide(object): - pass - class Bullet(object): - pass + id = Column(Integer, primary_key=True) + name = Column(String) - mapper(Slide, slides_table, properties={ - 'bullets': relationship(Bullet, order_by=[bullets_table.c.position]) - }) - mapper(Bullet, bullets_table) + bullets = relationship("Bullet", order_by="Bullet.position") -The standard relationship mapping will produce a list-like attribute on each Slide -containing all related Bullets, but coping with changes in ordering is totally -your responsibility. If you insert a Bullet into that list, there is no -magic- it won't have a position attribute unless you assign it it one, and -you'll need to manually renumber all the subsequent Bullets in the list to -accommodate the insert. + class Bullet(Base): + __tablename__ = 'bullet' + id = Column(Integer, primary_key=True) + slide_id = Column(Integer, ForeignKey('slide.id')) + position = Column(Integer) + text = Column(String) -An ``orderinglist`` can automate this and manage the 'position' attribute on all -related bullets for you. +The standard relationship mapping will produce a list-like attribute on each +``Slide`` containing all related ``Bullet`` objects, +but coping with changes in ordering is not handled automatically. +When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position`` +attribute will remain unset until manually assigned. When the ``Bullet`` +is inserted into the middle of the list, the following ``Bullet`` objects +will also need to be renumbered. -.. sourcecode:: python+sql - - mapper(Slide, slides_table, properties={ - 'bullets': relationship(Bullet, - collection_class=ordering_list('position'), - order_by=[bullets_table.c.position]) - }) - mapper(Bullet, bullets_table) +The :class:`.OrderingList` object automates this task, managing the +``position`` attribute on all ``Bullet`` objects in the collection. It is +constructed using the :func:`.ordering_list` factory:: + + from sqlalchemy.ext.orderinglist import ordering_list + + Base = declarative_base() + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + class Bullet(Base): + __tablename__ = 'bullet' + id = Column(Integer, primary_key=True) + slide_id = Column(Integer, ForeignKey('slide.id')) + position = Column(Integer) + text = Column(String) + +With the above mapping the ``Bullet.position`` attribute is managed:: s = Slide() s.bullets.append(Bullet()) @@ -63,71 +83,98 @@ related bullets for you. s.bullets[2].position >>> 2 -Use the ``ordering_list`` function to set up the ``collection_class`` on relationships -(as in the mapper example above). This implementation depends on the list -starting in the proper order, so be SURE to put an order_by on your relationship. +The :class:`.OrderingList` construct only works with **changes** to a +collection, and not the initial load from the database, and requires that the +list be sorted when loaded. 
Therefore, be sure to specify ``order_by`` on the
+:func:`.relationship` against the target ordering attribute, so that the
+ordering is correct when first loaded.

-.. warning:: ``ordering_list`` only provides limited functionality when a primary
-  key column or unique column is the target of the sort. Since changing the order of
-  entries often means that two rows must trade values, this is not possible when
-  the value is constrained by a primary key or unique constraint, since one of the rows
-  would temporarily have to point to a third available value so that the other row
-  could take its old value. ``ordering_list`` doesn't do any of this for you,
-  nor does SQLAlchemy itself.
+.. warning::
- +
+  :class:`.OrderingList` only provides limited functionality when a primary
+  key column or unique column is the target of the sort. Operations
+  that are unsupported or are problematic include:

-``ordering_list`` takes the name of the related object's ordering attribute as
-an argument. By default, the zero-based integer index of the object's
-position in the ``ordering_list`` is synchronized with the ordering attribute:
-index 0 will get position 0, index 1 position 1, etc. To start numbering at 1
-or some other integer, provide ``count_from=1``.
+    * two entries must trade values. This is not supported directly in the
+      case of a primary key or unique constraint because it means at least
+      one row would need to be temporarily removed first, or changed to
+      a third, neutral value while the switch occurs.

-Ordering values are not limited to incrementing integers. Almost any scheme
-can implemented by supplying a custom ``ordering_func`` that maps a Python list
-index to any value you require.
+    * an entry must be deleted in order to make room for a new entry.
+      SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
+      single flush. In the case of a primary key, it will trade
+      an INSERT/DELETE of the same primary key for an UPDATE statement in order
+      to lessen the impact of this limitation; however, this does not take place
+      for a UNIQUE column.
+      A future feature will allow the "DELETE before INSERT" behavior to be
+      possible, alleviating this limitation, though this feature will require
+      explicit configuration at the mapper level for sets of columns that
+      are to be handled in this way.

+:func:`.ordering_list` takes the name of the related object's ordering
+attribute as an argument. By default, the zero-based integer index of the
+object's position in the :func:`.ordering_list` is synchronized with the
+ordering attribute: index 0 will get position 0, index 1 position 1, etc. To
+start numbering at 1 or some other integer, provide ``count_from=1``.

"""
-from sqlalchemy.orm.collections import collection
-from sqlalchemy import util
+from ..orm.collections import collection, collection_adapter
+from .. import util

-__all__ = [ 'ordering_list' ]
+__all__ = ['ordering_list']

def ordering_list(attr, count_from=None, **kw):
-    """Prepares an OrderingList factory for use in mapper definitions.
+    """Prepares an :class:`OrderingList` factory for use in mapper definitions.

-    Returns an object suitable for use as an argument to a Mapper relationship's
-    ``collection_class`` option. Arguments are:
+    Returns an object suitable for use as an argument to a Mapper
+    relationship's ``collection_class`` option.
e.g.:: - attr + from sqlalchemy.ext.orderinglist import ordering_list + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + :param attr: Name of the mapped attribute to use for storage and retrieval of ordering information - count_from (optional) + :param count_from: Set up an integer-based ordering, starting at ``count_from``. For example, ``ordering_list('pos', count_from=1)`` would create a 1-based list in SQL, storing the value in the 'pos' column. Ignored if ``ordering_func`` is supplied. - Passes along any keyword arguments to ``OrderingList`` constructor. + Additional arguments are passed to the :class:`.OrderingList` constructor. + """ kw = _unsugar_count_from(count_from=count_from, **kw) return lambda: OrderingList(attr, **kw) + # Ordering utility functions + + def count_from_0(index, collection): """Numbering function: consecutive integers starting at 0.""" return index + def count_from_1(index, collection): """Numbering function: consecutive integers starting at 1.""" return index + 1 + def count_from_n_factory(start): """Numbering function: consecutive integers starting at arbitrary start.""" @@ -139,8 +186,9 @@ def count_from_n_factory(start): pass return f + def _unsugar_count_from(**kw): - """Builds counting functions from keywrod arguments. + """Builds counting functions from keyword arguments. Keyword argument filter, prepares a simple ``ordering_func`` from a ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. @@ -156,12 +204,13 @@ def _unsugar_count_from(**kw): kw['ordering_func'] = count_from_n_factory(count_from) return kw + class OrderingList(list): """A custom list that manages position information for its children. - See the module and __init__ documentation for more details. The - ``ordering_list`` factory function is used to configure ``OrderingList`` - collections in ``mapper`` relationship definitions. + The :class:`.OrderingList` object is normally set up using the + :func:`.ordering_list` factory function, used in conjunction with + the :func:`.relationship` function. """ @@ -176,14 +225,14 @@ class OrderingList(list): This implementation relies on the list starting in the proper order, so be **sure** to put an ``order_by`` on your relationship. - ordering_attr + :param ordering_attr: Name of the attribute that stores the object's order in the relationship. - ordering_func - Optional. A function that maps the position in the Python list to a - value to store in the ``ordering_attr``. Values returned are - usually (but need not be!) integers. + :param ordering_func: Optional. A function that maps the position in + the Python list to a value to store in the + ``ordering_attr``. Values returned are usually (but need not be!) + integers. An ``ordering_func`` is called with two positional parameters: the index of the element in the list, and the list itself. @@ -194,7 +243,7 @@ class OrderingList(list): like stepped numbering, alphabetical and Fibonacci numbering, see the unit tests. - reorder_on_append + :param reorder_on_append: Default False. When appending an object with an existing (non-None) ordering value, that value will be left untouched unless ``reorder_on_append`` is true. 
This is an optimization to avoid a @@ -208,7 +257,7 @@ class OrderingList(list): making changes, any of whom happen to load this collection even in passing, all of the sessions would try to "clean up" the numbering in their commits, possibly causing all but one to fail with a - concurrent modification error. Spooky action at a distance. + concurrent modification error. Recommend leaving this with the default of False, and just call ``reorder()`` if you're doing ``append()`` operations with @@ -270,7 +319,10 @@ class OrderingList(list): def remove(self, entity): super(OrderingList, self).remove(entity) - self._reorder() + + adapter = collection_adapter(self) + if adapter and adapter._referenced_by_owner: + self._reorder() def pop(self, index=-1): entity = super(OrderingList, self).pop(index) @@ -286,8 +338,8 @@ class OrderingList(list): stop = index.stop or len(self) if stop < 0: stop += len(self) - - for i in xrange(start, stop, step): + + for i in range(start, stop, step): self.__setitem__(i, entity[i]) else: self._order_entity(index, entity, True) @@ -297,7 +349,6 @@ class OrderingList(list): super(OrderingList, self).__delitem__(index) self._reorder() - # Py2K def __setslice__(self, start, end, values): super(OrderingList, self).__setslice__(start, end, values) self._reorder() @@ -305,11 +356,25 @@ class OrderingList(list): def __delslice__(self, start, end): super(OrderingList, self).__delslice__(start, end) self._reorder() - # end Py2K - - for func_name, func in locals().items(): - if (util.callable(func) and func.func_name == func_name and - not func.__doc__ and hasattr(list, func_name)): + + def __reduce__(self): + return _reconstitute, (self.__class__, self.__dict__, list(self)) + + for func_name, func in list(locals().items()): + if (util.callable(func) and func.__name__ == func_name and + not func.__doc__ and hasattr(list, func_name)): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func + +def _reconstitute(cls, dict_, items): + """ Reconstitute an :class:`.OrderingList`. + + This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for + unpickling :class:`.OrderingList` objects. + + """ + obj = cls.__new__(cls) + obj.__dict__.update(dict_) + list.extend(obj, items) + return obj diff --git a/sqlalchemy/ext/serializer.py b/sqlalchemy/ext/serializer.py index 354f28c..2fbc62e 100644 --- a/sqlalchemy/ext/serializer.py +++ b/sqlalchemy/ext/serializer.py @@ -1,4 +1,11 @@ -"""Serializer/Deserializer objects for usage with SQLAlchemy query structures, +# ext/serializer.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Serializer/Deserializer objects for usage with SQLAlchemy query structures, allowing "contextual" deserialization. Any SQLAlchemy query structure, either based on sqlalchemy.sql.* @@ -12,82 +19,73 @@ Usage is nearly the same as that of the standard Python pickle module:: from sqlalchemy.ext.serializer import loads, dumps metadata = MetaData(bind=some_engine) Session = scoped_session(sessionmaker()) - + # ... define mappers - - query = Session.query(MyClass).filter(MyClass.somedata=='foo').order_by(MyClass.sortkey) - + + query = Session.query(MyClass). + filter(MyClass.somedata=='foo').order_by(MyClass.sortkey) + # pickle the query serialized = dumps(query) - + # unpickle. 
Pass in metadata + scoped_session
    query2 = loads(serialized, metadata, Session)
- +
    print query2.all()

-Similar restrictions as when using raw pickle apply; mapped classes must be
+Similar restrictions as when using raw pickle apply; mapped classes must
themselves be pickleable, meaning they are importable from a module-level namespace.

The serializer module is only appropriate for query structures. It is not needed for:

-* instances of user-defined classes. These contain no references to engines,
-  sessions or expression constructs in the typical case and can be serialized directly.
+* instances of user-defined classes. These contain no references to engines,
+  sessions or expression constructs in the typical case and can be serialized
+  directly.

-* Table metadata that is to be loaded entirely from the serialized structure (i.e. is
-  not already declared in the application). Regular pickle.loads()/dumps() can
-  be used to fully dump any ``MetaData`` object, typically one which was reflected
-  from an existing database at some previous point in time. The serializer module
-  is specifically for the opposite case, where the Table metadata is already present
-  in memory.
+* Table metadata that is to be loaded entirely from the serialized structure
+  (i.e. is not already declared in the application). Regular
+  pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
+  typically one which was reflected from an existing database at some previous
+  point in time. The serializer module is specifically for the opposite case,
+  where the Table metadata is already present in memory.

"""
-from sqlalchemy.orm import class_mapper, Query
-from sqlalchemy.orm.session import Session
-from sqlalchemy.orm.mapper import Mapper
-from sqlalchemy.orm.attributes import QueryableAttribute
-from sqlalchemy import Table, Column
-from sqlalchemy.engine import Engine
-from sqlalchemy.util import pickle
+from ..orm import class_mapper
+from ..orm.session import Session
+from ..orm.mapper import Mapper
+from ..orm.interfaces import MapperProperty
+from ..orm.attributes import QueryableAttribute
+from ..
import Table, Column +from ..engine import Engine +from ..util import pickle, byte_buffer, b64encode, b64decode, text_type import re -import base64 -# Py3K -#from io import BytesIO as byte_buffer -# Py2K -from cStringIO import StringIO as byte_buffer -# end Py2K -# Py3K -#def b64encode(x): -# return base64.b64encode(x).decode('ascii') -#def b64decode(x): -# return base64.b64decode(x.encode('ascii')) -# Py2K -b64encode = base64.b64encode -b64decode = base64.b64decode -# end Py2K __all__ = ['Serializer', 'Deserializer', 'dumps', 'loads'] - def Serializer(*args, **kw): pickler = pickle.Pickler(*args, **kw) - + def persistent_id(obj): - #print "serializing:", repr(obj) + # print "serializing:", repr(obj) if isinstance(obj, QueryableAttribute): cls = obj.impl.class_ key = obj.impl.key id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls)) elif isinstance(obj, Mapper) and not obj.non_primary: id = "mapper:" + b64encode(pickle.dumps(obj.class_)) + elif isinstance(obj, MapperProperty) and not obj.parent.non_primary: + id = "mapperprop:" + b64encode(pickle.dumps(obj.parent.class_)) + \ + ":" + obj.key elif isinstance(obj, Table): - id = "table:" + str(obj) + id = "table:" + text_type(obj.key) elif isinstance(obj, Column) and isinstance(obj.table, Table): - id = "column:" + str(obj.table) + ":" + obj.key + id = "column:" + \ + text_type(obj.table.key) + ":" + text_type(obj.key) elif isinstance(obj, Session): id = "session:" elif isinstance(obj, Engine): @@ -95,15 +93,17 @@ def Serializer(*args, **kw): else: return None return id - + pickler.persistent_id = persistent_id return pickler - -our_ids = re.compile(r'(mapper|table|column|session|attribute|engine):(.*)') + +our_ids = re.compile( + r'(mapperprop|mapper|table|column|session|attribute|engine):(.*)') + def Deserializer(file, metadata=None, scoped_session=None, engine=None): unpickler = pickle.Unpickler(file) - + def get_engine(): if engine: return engine @@ -113,9 +113,9 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): return metadata.bind else: return None - + def persistent_load(id): - m = our_ids.match(id) + m = our_ids.match(text_type(id)) if not m: return None else: @@ -127,6 +127,10 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): elif type_ == "mapper": cls = pickle.loads(b64decode(args)) return class_mapper(cls) + elif type_ == "mapperprop": + mapper, keyname = args.split(':') + cls = pickle.loads(b64decode(mapper)) + return class_mapper(cls).attrs[keyname] elif type_ == "table": return metadata.tables[args] elif type_ == "column": @@ -141,15 +145,15 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): unpickler.persistent_load = persistent_load return unpickler + def dumps(obj, protocol=0): buf = byte_buffer() pickler = Serializer(buf, protocol) pickler.dump(obj) return buf.getvalue() - + + def loads(data, metadata=None, scoped_session=None, engine=None): buf = byte_buffer(data) unpickler = Deserializer(buf, metadata, scoped_session, engine) return unpickler.load() - - diff --git a/sqlalchemy/interfaces.py b/sqlalchemy/interfaces.py index c2a267d..33f3cf1 100644 --- a/sqlalchemy/interfaces.py +++ b/sqlalchemy/interfaces.py @@ -1,31 +1,45 @@ -# interfaces.py +# sqlalchemy/interfaces.py +# Copyright (C) 2007-2017 the SQLAlchemy authors and contributors +# # Copyright (C) 2007 Jason Kirtland jek@discorporate.us # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Interfaces and 
abstract types.""" +"""Deprecated core event interfaces. + +This module is **deprecated** and is superseded by the +event system. + +""" + +from . import event, util class PoolListener(object): - """Hooks into the lifecycle of connections in a ``Pool``. + """Hooks into the lifecycle of connections in a :class:`.Pool`. + + .. note:: + + :class:`.PoolListener` is deprecated. Please + refer to :class:`.PoolEvents`. Usage:: - + class MyListener(PoolListener): def connect(self, dbapi_con, con_record): '''perform connect operations''' - # etc. - + # etc. + # create a new pool with a listener p = QueuePool(..., listeners=[MyListener()]) - + # add a listener after the fact p.add_listener(MyListener()) - + # usage with create_engine() e = create_engine("url://", listeners=[MyListener()]) - + All of the standard connection :class:`~sqlalchemy.pool.Pool` types can accept event listeners for key connection lifecycle events: creation, pool check-out and check-in. There are no events fired @@ -56,9 +70,28 @@ class PoolListener(object): internal event queues based on its capabilities. In terms of efficiency and function call overhead, you're much better off only providing implementations for the hooks you'll be using. - + """ + @classmethod + def _adapt_listener(cls, self, listener): + """Adapt a :class:`.PoolListener` to individual + :class:`event.Dispatch` events. + + """ + + listener = util.as_interface(listener, + methods=('connect', 'first_connect', + 'checkout', 'checkin')) + if hasattr(listener, 'connect'): + event.listen(self, 'connect', listener.connect) + if hasattr(listener, 'first_connect'): + event.listen(self, 'first_connect', listener.first_connect) + if hasattr(listener, 'checkout'): + event.listen(self, 'checkout', listener.checkout) + if hasattr(listener, 'checkin'): + event.listen(self, 'checkin', listener.checkin) + def connect(self, dbapi_con, con_record): """Called once for each new DB-API connection or Pool's ``creator()``. @@ -117,89 +150,163 @@ class PoolListener(object): """ + class ConnectionProxy(object): """Allows interception of statement execution by Connections. - + + .. note:: + + :class:`.ConnectionProxy` is deprecated. Please + refer to :class:`.ConnectionEvents`. + Either or both of the ``execute()`` and ``cursor_execute()`` may be implemented to intercept compiled statement and cursor level executions, e.g.:: - + class MyProxy(ConnectionProxy): - def execute(self, conn, execute, clauseelement, *multiparams, **params): + def execute(self, conn, execute, clauseelement, + *multiparams, **params): print "compiled statement:", clauseelement return execute(clauseelement, *multiparams, **params) - - def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + + def cursor_execute(self, execute, cursor, statement, + parameters, context, executemany): print "raw statement:", statement return execute(cursor, statement, parameters, context) The ``execute`` argument is a function that will fulfill the default execution behavior for the operation. The signature illustrated in the example should be used. 
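+    Under the event system which supersedes these interfaces, the
+    cursor-level interception above would be along the lines of the
+    following; a minimal sketch, assuming ``e`` is an existing
+    :class:`.Engine`::
+
+        from sqlalchemy import event
+
+        @event.listens_for(e, "before_cursor_execute")
+        def before_cursor_execute(conn, cursor, statement,
+                                  parameters, context, executemany):
+            print "raw statement:", statement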
- + The proxy is installed into an :class:`~sqlalchemy.engine.Engine` via the ``proxy`` argument:: - + e = create_engine('someurl://', proxy=MyProxy()) - + """ + + @classmethod + def _adapt_listener(cls, self, listener): + + def adapt_execute(conn, clauseelement, multiparams, params): + + def execute_wrapper(clauseelement, *multiparams, **params): + return clauseelement, multiparams, params + + return listener.execute(conn, execute_wrapper, + clauseelement, *multiparams, + **params) + + event.listen(self, 'before_execute', adapt_execute) + + def adapt_cursor_execute(conn, cursor, statement, + parameters, context, executemany): + + def execute_wrapper( + cursor, + statement, + parameters, + context, + ): + return statement, parameters + + return listener.cursor_execute( + execute_wrapper, + cursor, + statement, + parameters, + context, + executemany, + ) + + event.listen(self, 'before_cursor_execute', adapt_cursor_execute) + + def do_nothing_callback(*arg, **kw): + pass + + def adapt_listener(fn): + + def go(conn, *arg, **kw): + fn(conn, do_nothing_callback, *arg, **kw) + + return util.update_wrapper(go, fn) + + event.listen(self, 'begin', adapt_listener(listener.begin)) + event.listen(self, 'rollback', + adapt_listener(listener.rollback)) + event.listen(self, 'commit', adapt_listener(listener.commit)) + event.listen(self, 'savepoint', + adapt_listener(listener.savepoint)) + event.listen(self, 'rollback_savepoint', + adapt_listener(listener.rollback_savepoint)) + event.listen(self, 'release_savepoint', + adapt_listener(listener.release_savepoint)) + event.listen(self, 'begin_twophase', + adapt_listener(listener.begin_twophase)) + event.listen(self, 'prepare_twophase', + adapt_listener(listener.prepare_twophase)) + event.listen(self, 'rollback_twophase', + adapt_listener(listener.rollback_twophase)) + event.listen(self, 'commit_twophase', + adapt_listener(listener.commit_twophase)) + def execute(self, conn, execute, clauseelement, *multiparams, **params): """Intercept high level execute() events.""" - + return execute(clauseelement, *multiparams, **params) - def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + def cursor_execute(self, execute, cursor, statement, parameters, + context, executemany): """Intercept low-level cursor execute() events.""" - + return execute(cursor, statement, parameters, context) - + def begin(self, conn, begin): """Intercept begin() events.""" - + return begin() - + def rollback(self, conn, rollback): """Intercept rollback() events.""" - + return rollback() - + def commit(self, conn, commit): """Intercept commit() events.""" - + return commit() - + def savepoint(self, conn, savepoint, name=None): """Intercept savepoint() events.""" - + return savepoint(name=name) - + def rollback_savepoint(self, conn, rollback_savepoint, name, context): """Intercept rollback_savepoint() events.""" - + return rollback_savepoint(name, context) - + def release_savepoint(self, conn, release_savepoint, name, context): """Intercept release_savepoint() events.""" - + return release_savepoint(name, context) - + def begin_twophase(self, conn, begin_twophase, xid): """Intercept begin_twophase() events.""" - + return begin_twophase(xid) - + def prepare_twophase(self, conn, prepare_twophase, xid): """Intercept prepare_twophase() events.""" - + return prepare_twophase(xid) - + def rollback_twophase(self, conn, rollback_twophase, xid, is_prepared): """Intercept rollback_twophase() events.""" - + return rollback_twophase(xid, is_prepared) - + def 
commit_twophase(self, conn, commit_twophase, xid, is_prepared): """Intercept commit_twophase() events.""" - + return commit_twophase(xid, is_prepared) - diff --git a/sqlalchemy/log.py b/sqlalchemy/log.py index 49c779f..279538a 100644 --- a/sqlalchemy/log.py +++ b/sqlalchemy/log.py @@ -1,5 +1,7 @@ -# log.py - adapt python logging module to SQLAlchemy -# Copyright (C) 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# sqlalchemy/log.py +# Copyright (C) 2006-2017 the SQLAlchemy authors and contributors +# +# Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -10,92 +12,189 @@ Control of logging for SA can be performed from the regular python logging module. The regular dotted module namespace is used, starting at 'sqlalchemy'. For class-level logging, the class name is appended. -The "echo" keyword parameter which is available on SQLA ``Engine`` -and ``Pool`` objects corresponds to a logger specific to that +The "echo" keyword parameter, available on SQLA :class:`.Engine` +and :class:`.Pool` objects, corresponds to a logger specific to that instance only. -E.g.:: - - engine.echo = True - -is equivalent to:: - - import logging - logger = logging.getLogger('sqlalchemy.engine.Engine.%s' % hex(id(engine))) - logger.setLevel(logging.DEBUG) - """ import logging import sys -from sqlalchemy import util +# set initial level to WARN. This so that +# log statements don't occur in the absence of explicit +# logging being enabled for 'sqlalchemy'. rootlogger = logging.getLogger('sqlalchemy') if rootlogger.level == logging.NOTSET: rootlogger.setLevel(logging.WARN) -default_enabled = False -def default_logging(name): - global default_enabled - if logging.getLogger(name).getEffectiveLevel() < logging.WARN: - default_enabled = True - if not default_enabled: - default_enabled = True - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter( - '%(asctime)s %(levelname)s %(name)s %(message)s')) - rootlogger.addHandler(handler) + +def _add_default_handler(logger): + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter( + '%(asctime)s %(levelname)s %(name)s %(message)s')) + logger.addHandler(handler) + _logged_classes = set() -def class_logger(cls, enable=False): + + +def class_logger(cls): logger = logging.getLogger(cls.__module__ + "." + cls.__name__) - if enable == 'debug': - logger.setLevel(logging.DEBUG) - elif enable == 'info': - logger.setLevel(logging.INFO) cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG) cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO) cls.logger = logger _logged_classes.add(cls) - + return cls + class Identified(object): - @util.memoized_property - def logging_name(self): - # limit the number of loggers by chopping off the hex(id). - # some novice users unfortunately create an unlimited number - # of Engines in their applications which would otherwise - # cause the app to run out of memory. - return "0x...%s" % hex(id(self))[-4:] + logging_name = None - -def instance_logger(instance, echoflag=None): - """create a logger for an instance that implements :class:`Identified`. - - Warning: this is an expensive call which also results in a permanent - increase in memory overhead for each call. Use only for - low-volume, long-time-spanning objects. 
- + def _should_log_debug(self): + return self.logger.isEnabledFor(logging.DEBUG) + + def _should_log_info(self): + return self.logger.isEnabledFor(logging.INFO) + + +class InstanceLogger(object): + """A logger adapter (wrapper) for :class:`.Identified` subclasses. + + This allows multiple instances (e.g. Engine or Pool instances) + to share a logger, but have its verbosity controlled on a + per-instance basis. + + The basic functionality is to return a logging level + which is based on an instance's echo setting. + + Default implementation is: + + 'debug' -> logging.DEBUG + True -> logging.INFO + False -> Effective level of underlying logger + (logging.WARNING by default) + None -> same as False """ - name = "%s.%s.%s" % (instance.__class__.__module__, - instance.__class__.__name__, instance.logging_name) - - if echoflag is not None: - l = logging.getLogger(name) - if echoflag == 'debug': - default_logging(name) - l.setLevel(logging.DEBUG) - elif echoflag is True: - default_logging(name) - l.setLevel(logging.INFO) - elif echoflag is False: - l.setLevel(logging.WARN) + # Map echo settings to logger levels + _echo_map = { + None: logging.NOTSET, + False: logging.NOTSET, + True: logging.INFO, + 'debug': logging.DEBUG, + } + + def __init__(self, echo, name): + self.echo = echo + self.logger = logging.getLogger(name) + + # if echo flag is enabled and no handlers, + # add a handler to the list + if self._echo_map[echo] <= logging.INFO \ + and not self.logger.handlers: + _add_default_handler(self.logger) + + # + # Boilerplate convenience methods + # + def debug(self, msg, *args, **kwargs): + """Delegate a debug call to the underlying logger.""" + + self.log(logging.DEBUG, msg, *args, **kwargs) + + def info(self, msg, *args, **kwargs): + """Delegate an info call to the underlying logger.""" + + self.log(logging.INFO, msg, *args, **kwargs) + + def warning(self, msg, *args, **kwargs): + """Delegate a warning call to the underlying logger.""" + + self.log(logging.WARNING, msg, *args, **kwargs) + + warn = warning + + def error(self, msg, *args, **kwargs): + """ + Delegate an error call to the underlying logger. + """ + self.log(logging.ERROR, msg, *args, **kwargs) + + def exception(self, msg, *args, **kwargs): + """Delegate an exception call to the underlying logger.""" + + kwargs["exc_info"] = 1 + self.log(logging.ERROR, msg, *args, **kwargs) + + def critical(self, msg, *args, **kwargs): + """Delegate a critical call to the underlying logger.""" + + self.log(logging.CRITICAL, msg, *args, **kwargs) + + def log(self, level, msg, *args, **kwargs): + """Delegate a log call to the underlying logger. + + The level here is determined by the echo + flag as well as that of the underlying logger, and + logger._log() is called directly. + + """ + + # inline the logic from isEnabledFor(), + # getEffectiveLevel(), to avoid overhead. 
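+        # the checks below mirror those two methods: first the manager's
+        # global "disable" threshold, then the echo-derived level, falling
+        # back to the wrapped logger's effective level when echo is
+        # None or False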
+ + if self.logger.manager.disable >= level: + return + + selected_level = self._echo_map[self.echo] + if selected_level == logging.NOTSET: + selected_level = self.logger.getEffectiveLevel() + + if level >= selected_level: + self.logger._log(level, msg, args, **kwargs) + + def isEnabledFor(self, level): + """Is this logger enabled for level 'level'?""" + + if self.logger.manager.disable >= level: + return False + return level >= self.getEffectiveLevel() + + def getEffectiveLevel(self): + """What's the effective level for this logger?""" + + level = self._echo_map[self.echo] + if level == logging.NOTSET: + level = self.logger.getEffectiveLevel() + return level + + +def instance_logger(instance, echoflag=None): + """create a logger for an instance that implements :class:`.Identified`.""" + + if instance.logging_name: + name = "%s.%s.%s" % (instance.__class__.__module__, + instance.__class__.__name__, + instance.logging_name) else: - l = logging.getLogger(name) - instance._should_log_debug = lambda: l.isEnabledFor(logging.DEBUG) - instance._should_log_info = lambda: l.isEnabledFor(logging.INFO) - return l + name = "%s.%s" % (instance.__class__.__module__, + instance.__class__.__name__) + + instance._echo = echoflag + + if echoflag in (False, None): + # if no echo setting or False, return a Logger directly, + # avoiding overhead of filtering + logger = logging.getLogger(name) + else: + # if a specified echo flag, return an EchoLogger, + # which checks the flag, overrides normal log + # levels by calling logger._log() + logger = InstanceLogger(echoflag, name) + + instance.logger = logger + class echo_property(object): __doc__ = """\ @@ -112,8 +211,7 @@ class echo_property(object): if instance is None: return self else: - return instance._should_log_debug() and 'debug' or \ - (instance._should_log_info() and True or False) + return instance._echo def __set__(self, instance, value): instance_logger(instance, echoflag=value) diff --git a/sqlalchemy/orm/__init__.py b/sqlalchemy/orm/__init__.py index 206c8d0..4491735 100644 --- a/sqlalchemy/orm/__init__.py +++ b/sqlalchemy/orm/__init__.py @@ -1,5 +1,6 @@ -# sqlalchemy/orm/__init__.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# orm/__init__.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -12,149 +13,77 @@ documentation for an overview of how this module is used. 
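+A minimal usage sketch, assuming ``some_engine`` refers to an existing
+:class:`.Engine`, constructing a session factory and a working session::
+
+    from sqlalchemy.orm import sessionmaker
+
+    Session = sessionmaker(bind=some_engine)
+    session = Session()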
""" -from sqlalchemy.orm import exc -from sqlalchemy.orm.mapper import ( - Mapper, - _mapper_registry, - class_mapper, - ) -from sqlalchemy.orm.interfaces import ( - EXT_CONTINUE, - EXT_STOP, - ExtensionOption, - InstrumentationManager, - MapperExtension, - PropComparator, - SessionExtension, - AttributeExtension, - ) -from sqlalchemy.orm.util import ( - AliasedClass as aliased, - Validator, - join, - object_mapper, - outerjoin, - polymorphic_union, - with_parent, - ) -from sqlalchemy.orm.properties import ( - ColumnProperty, - ComparableProperty, - CompositeProperty, - RelationshipProperty, - PropertyLoader, - SynonymProperty, - ) -from sqlalchemy.orm import mapper as mapperlib -from sqlalchemy.orm.mapper import reconstructor, validates -from sqlalchemy.orm import strategies -from sqlalchemy.orm.query import AliasOption, Query -from sqlalchemy.sql import util as sql_util -from sqlalchemy.orm.session import Session as _Session -from sqlalchemy.orm.session import object_session, sessionmaker, make_transient -from sqlalchemy.orm.scoping import ScopedSession -from sqlalchemy import util as sa_util +from . import exc +from .mapper import ( + Mapper, + _mapper_registry, + class_mapper, + configure_mappers, + reconstructor, + validates +) +from .interfaces import ( + EXT_CONTINUE, + EXT_STOP, + PropComparator, +) +from .deprecated_interfaces import ( + MapperExtension, + SessionExtension, + AttributeExtension, +) +from .util import ( + aliased, + join, + object_mapper, + outerjoin, + polymorphic_union, + was_deleted, + with_parent, + with_polymorphic, +) +from .properties import ColumnProperty +from .relationships import RelationshipProperty +from .descriptor_props import ( + ComparableProperty, + CompositeProperty, + SynonymProperty, +) +from .relationships import ( + foreign, + remote, +) +from .session import ( + Session, + object_session, + sessionmaker, + make_transient, + make_transient_to_detached +) +from .scoping import ( + scoped_session +) +from . import mapper as mapperlib +from .query import AliasOption, Query, Bundle +from ..util.langhelpers import public_factory +from .. import util as _sa_util +from . import strategies as _strategies -__all__ = ( - 'EXT_CONTINUE', - 'EXT_STOP', - 'InstrumentationManager', - 'MapperExtension', - 'AttributeExtension', - 'Validator', - 'PropComparator', - 'Query', - 'aliased', - 'backref', - 'class_mapper', - 'clear_mappers', - 'column_property', - 'comparable_property', - 'compile_mappers', - 'composite', - 'contains_alias', - 'contains_eager', - 'create_session', - 'defer', - 'deferred', - 'dynamic_loader', - 'eagerload', - 'eagerload_all', - 'extension', - 'join', - 'joinedload', - 'joinedload_all', - 'lazyload', - 'mapper', - 'make_transient', - 'noload', - 'object_mapper', - 'object_session', - 'outerjoin', - 'polymorphic_union', - 'reconstructor', - 'relationship', - 'relation', - 'scoped_session', - 'sessionmaker', - 'subqueryload', - 'subqueryload_all', - 'synonym', - 'undefer', - 'undefer_group', - 'validates' - ) - - -def scoped_session(session_factory, scopefunc=None): - """Provides thread-local management of Sessions. - - This is a front-end function to - :class:`~sqlalchemy.orm.scoping.ScopedSession`. - - :param session_factory: a callable function that produces - :class:`Session` instances, such as :func:`sessionmaker` or - :func:`create_session`. 
- - :param scopefunc: optional, TODO - - :returns: an :class:`~sqlalchemy.orm.scoping.ScopedSession` instance - - Usage:: - - Session = scoped_session(sessionmaker(autoflush=True)) - - To instantiate a Session object which is part of the scoped context, - instantiate normally:: - - session = Session() - - Most session methods are available as classmethods from the scoped - session:: - - Session.commit() - Session.close() - - To map classes so that new instances are saved in the current Session - automatically, as well as to provide session-aware class attributes such - as "query", use the `mapper` classmethod from the scoped session:: - - mapper = Session.mapper - mapper(Class, table, ...) - - """ - return ScopedSession(session_factory, scopefunc=scopefunc) - def create_session(bind=None, **kwargs): - """Create a new :class:`~sqlalchemy.orm.session.Session`. + r"""Create a new :class:`.Session` + with no automation enabled by default. + + This function is used primarily for testing. The usual + route to :class:`.Session` creation is via its constructor + or the :func:`.sessionmaker` function. :param bind: optional, a single Connectable to use for all database access in the created :class:`~sqlalchemy.orm.session.Session`. :param \*\*kwargs: optional, passed through to the - :class:`Session` constructor. + :class:`.Session` constructor. :returns: an :class:`~sqlalchemy.orm.session.Session` instance @@ -175,738 +104,118 @@ def create_session(bind=None, **kwargs): kwargs.setdefault('autoflush', False) kwargs.setdefault('autocommit', True) kwargs.setdefault('expire_on_commit', False) - return _Session(bind=bind, **kwargs) + return Session(bind=bind, **kwargs) -def relationship(argument, secondary=None, **kwargs): - """Provide a relationship of a primary Mapper to a secondary Mapper. - - .. note:: This function is known as :func:`relation` in all versions - of SQLAlchemy prior to version 0.6beta2, including the 0.5 and 0.4 series. - :func:`~sqlalchemy.orm.relationship()` is only available starting with - SQLAlchemy 0.6beta2. The :func:`relation` name will remain available for - the foreseeable future in order to enable cross-compatibility. - - This corresponds to a parent-child or associative table relationship. The - constructed class is an instance of :class:`RelationshipProperty`. +relationship = public_factory(RelationshipProperty, ".orm.relationship") - A typical :func:`relationship`:: - - mapper(Parent, properties={ - 'children': relationship(Children) - }) - - :param argument: - a class or :class:`Mapper` instance, representing the target of - the relationship. - - :param secondary: - for a many-to-many relationship, specifies the intermediary - table. The *secondary* keyword argument should generally only - be used for a table that is not otherwise expressed in any class - mapping. In particular, using the Association Object Pattern is - generally mutually exclusive with the use of the *secondary* - keyword argument. - - :param backref: - indicates the string name of a property to be placed on the related - mapper's class that will handle this relationship in the other - direction. The other property will be created automatically - when the mappers are configured. Can also be passed as a - :func:`backref` object to control the configuration of the - new relationship. 
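As a concrete illustration of the ``backref`` parameter just described
(and of ``back_populates``, described next), a minimal classical-mapping
sketch; all table and class names here are hypothetical::

    from sqlalchemy import Table, Column, Integer, ForeignKey, MetaData
    from sqlalchemy.orm import mapper, relationship, backref

    metadata = MetaData()
    parent = Table('parent', metadata,
                   Column('id', Integer, primary_key=True))
    child = Table('child', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('parent_id', Integer, ForeignKey('parent.id')))

    class Parent(object):
        pass

    class Child(object):
        pass

    # 'children' is configured here on Parent; the backref() adds a
    # complementary 'parent' attribute to Child when mappers configure
    mapper(Parent, parent, properties={
        'children': relationship(
            Child, backref=backref('parent', lazy='joined'))
    })
    mapper(Child, child)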
- - :param back_populates: - Takes a string name and has the same meaning as ``backref``, - except the complementing property is **not** created automatically, - and instead must be configured explicitly on the other mapper. The - complementing property should also indicate ``back_populates`` - to this relationship to ensure proper functioning. - - :param cascade: - a comma-separated list of cascade rules which determines how - Session operations should be "cascaded" from parent to child. - This defaults to ``False``, which means the default cascade - should be used. The default value is ``"save-update, merge"``. - - Available cascades are: - - * ``save-update`` - cascade the :meth:`~sqlalchemy.orm.session.Session.add` - operation. This cascade applies both to future and - past calls to :meth:`~sqlalchemy.orm.session.Session.add`, - meaning new items added to a collection or scalar relationship - get placed into the same session as that of the parent, and - also applies to items which have been removed from this - relationship but are still part of unflushed history. - - * ``merge`` - cascade the :meth:`~sqlalchemy.orm.session.Session.merge` - operation - - * ``expunge`` - cascade the :meth:`~sqlalchemy.orm.session.Session.expunge` - operation - - * ``delete`` - cascade the :meth:`~sqlalchemy.orm.session.Session.delete` - operation - - * ``delete-orphan`` - if an item of the child's type with no - parent is detected, mark it for deletion. Note that this - option prevents a pending item of the child's class from being - persisted without a parent present. - - * ``refresh-expire`` - cascade the :meth:`~sqlalchemy.orm.session.Session.expire` - and :meth:`~sqlalchemy.orm.session.Session.refresh` operations - - * ``all`` - shorthand for "save-update,merge, refresh-expire, - expunge, delete" - - :param collection_class: - a class or callable that returns a new list-holding object. will - be used in place of a plain list for storing elements. - - :param comparator_factory: - a class which extends :class:`RelationshipProperty.Comparator` which - provides custom SQL clause generation for comparison operations. - - :param extension: - an :class:`AttributeExtension` instance, or list of extensions, - which will be prepended to the list of attribute listeners for - the resulting descriptor placed on the class. These listeners - will receive append and set events before the operation - proceeds, and may be used to halt (via exception throw) or - change the value used in the operation. - - :param foreign_keys: - a list of columns which are to be used as "foreign key" columns. - this parameter should be used in conjunction with explicit - ``primaryjoin`` and ``secondaryjoin`` (if needed) arguments, and - the columns within the ``foreign_keys`` list should be present - within those join conditions. Normally, ``relationship()`` will - inspect the columns within the join conditions to determine - which columns are the "foreign key" columns, based on - information in the ``Table`` metadata. Use this argument when no - ForeignKey's are present in the join condition, or to override - the table-defined foreign keys. - - :param innerjoin=False: - when ``True``, joined eager loads will use an inner join to join - against related tables instead of an outer join. The purpose - of this option is strictly one of performance, as inner joins - generally perform better than outer joins. 
This flag can
-      be set to ``True`` when the relationship references an object
-      via many-to-one using local foreign keys that are not nullable,
-      or when the reference is one-to-one or a collection that is
-      guaranteed to have one or at least one entry.
-
-    :param join_depth:
-      when non-``None``, an integer value indicating how many levels
-      deep "eager" loaders should join on a self-referring or cyclical
-      relationship.  The number counts how many times the same Mapper
-      shall be present in the loading condition along a particular join
-      branch.  When left at its default of ``None``, eager loaders
-      will stop chaining when they encounter the same target mapper
-      which is already higher up in the chain.  This option applies
-      both to joined- and subquery- eager loaders.
-
-    :param lazy=('select'|'joined'|'subquery'|'noload'|'dynamic'): specifies
-      how the related items should be loaded. Values include:
-
-      * 'select' - items should be loaded lazily when the property is first
-        accessed.
-
-      * 'joined' - items should be loaded "eagerly" in the same query as
-        that of the parent, using a JOIN or LEFT OUTER JOIN.
-
-      * 'subquery' - items should be loaded "eagerly" within the same
-        query as that of the parent, using a second SQL statement
-        which issues a JOIN to a subquery of the original
-        statement.
-
-      * 'noload' - no loading should occur at any time.  This is to
-        support "write-only" attributes, or attributes which are
-        populated in some manner specific to the application.
-
-      * 'dynamic' - the attribute will return a pre-configured
-        :class:`~sqlalchemy.orm.query.Query` object for all read
-        operations, onto which further filtering operations can be
-        applied before iterating the results.  The dynamic
-        collection supports a limited set of mutation operations,
-        allowing ``append()`` and ``remove()``.  Changes to the
-        collection will not be visible until flushed
-        to the database, where it is then refetched upon iteration.
-
-      * True - a synonym for 'select'
-
-      * False - a synonym for 'joined'
-
-      * None - a synonym for 'noload'
-
-    :param order_by:
-      indicates the ordering that should be applied when loading these
-      items.
-
-    :param passive_deletes=False:
-       Indicates loading behavior during delete operations.
-
-       A value of True indicates that unloaded child items should not
-       be loaded during a delete operation on the parent.  Normally,
-       when a parent item is deleted, all child items are loaded so
-       that they can either be marked as deleted, or have their
-       foreign key to the parent set to NULL.  Marking this flag as
-       True usually implies an ON DELETE rule is in
-       place which will handle updating/deleting child rows on the
-       database side.
-
-       Additionally, setting the flag to the string value 'all' will
-       disable the "nulling out" of the child foreign keys, when there
-       is no delete or delete-orphan cascade enabled.  This is
-       typically used when a triggering or error raise scenario is in
-       place on the database side.  Note that the foreign key
-       attributes on in-session child objects will not be changed
-       after a flush occurs so this is a very special use-case
-       setting.
-
-    :param passive_updates=True:
-      Indicates loading and INSERT/UPDATE/DELETE behavior when the
-      source of a foreign key value changes (i.e. an "on update"
-      cascade), which are typically the primary key columns of the
-      source row.
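Before the ``passive_updates`` discussion continues below, a hedged
illustration of the ``lazy`` strategies enumerated above, again with
hypothetical names::

    from sqlalchemy import Table, Column, Integer, ForeignKey, MetaData
    from sqlalchemy.orm import mapper, relationship

    metadata = MetaData()
    users = Table('users', metadata,
                  Column('id', Integer, primary_key=True))
    addresses = Table('addresses', metadata,
                      Column('id', Integer, primary_key=True),
                      Column('user_id', Integer, ForeignKey('users.id')))
    orders = Table('orders', metadata,
                   Column('id', Integer, primary_key=True),
                   Column('user_id', Integer, ForeignKey('users.id')))

    class User(object):
        pass

    class Address(object):
        pass

    class Order(object):
        pass

    mapper(Address, addresses)
    mapper(Order, orders)
    mapper(User, users, properties={
        # 'select' (the default): emit a SELECT on first access
        'addresses': relationship(Address, lazy='select'),
        # 'joined': LEFT OUTER JOIN in the same query as the parent
        'orders': relationship(Order, lazy='joined'),
    })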
-
-      When True, it is assumed that ON UPDATE CASCADE is configured on
-      the foreign key in the database, and that the database will
-      handle propagation of an UPDATE from a source column to
-      dependent rows.  Note that with databases which enforce
-      referential integrity (i.e. PostgreSQL, MySQL with InnoDB tables),
-      ON UPDATE CASCADE is required for this operation.  The
-      relationship() will update the value of the attribute on related
-      items which are locally present in the session during a flush.
-
-      When False, it is assumed that the database does not enforce
-      referential integrity and will not be issuing its own CASCADE
-      operation for an update.  The relationship() will issue the
-      appropriate UPDATE statements to the database in response to the
-      change of a referenced key, and items locally present in the
-      session during a flush will also be refreshed.
-
-      This flag should probably be set to False if primary key changes
-      are expected and the database in use doesn't support CASCADE
-      (i.e. SQLite, MySQL MyISAM tables).
-
-      Also see the passive_updates flag on ``mapper()``.
-
-      A future SQLAlchemy release will provide a "detect" feature for
-      this flag.
-
-    :param post_update:
-      this indicates that the relationship should be handled by a
-      second UPDATE statement after an INSERT or before a
-      DELETE. Currently, it also will issue an UPDATE after the
-      instance was UPDATEd as well, although this technically should
-      be improved. This flag is used to handle saving bi-directional
-      dependencies between two individual rows (i.e. each row
-      references the other), where it would otherwise be impossible to
-      INSERT or DELETE both rows fully since one row exists before the
-      other. Use this flag when a particular mapping arrangement will
-      incur two rows that are dependent on each other, such as a table
-      that has a one-to-many relationship to a set of child rows, and
-      also has a column that references a single child row within that
-      list (i.e. both tables contain a foreign key to each other). If
-      a ``flush()`` operation returns an error that a "cyclical
-      dependency" was detected, this is a cue that you might want to
-      use ``post_update`` to "break" the cycle.
-
-    :param primaryjoin:
-      a ColumnElement (i.e. WHERE criterion) that will be used as the primary
-      join of this child object against the parent object, or in a
-      many-to-many relationship the join of the primary object to the
-      association table. By default, this value is computed based on the
-      foreign key relationships of the parent and child tables (or association
-      table).
-
-    :param remote_side:
-      used for self-referential relationships, indicates the column or
-      list of columns that form the "remote side" of the relationship.
-
-    :param secondaryjoin:
-      a ColumnElement (i.e. WHERE criterion) that will be used as the join of
-      an association table to the child object. By default, this value is
-      computed based on the foreign key relationships of the association and
-      child tables.
-
-    :param single_parent=(True|False):
-      when True, installs a validator which will prevent objects
-      from being associated with more than one parent at a time.
-      This is used for many-to-one or many-to-many relationships that
-      should be treated either as one-to-one or one-to-many.  Its
-      usage is optional unless delete-orphan cascade is also
-      set on this relationship(), in which case it's required (new in 0.5.2).
-
-    :param uselist=(True|False):
-      a boolean that indicates if this property should be loaded as a
-      list or a scalar.
In most cases, this value is determined
-      automatically by ``relationship()``, based on the type and direction
-      of the relationship - one to many forms a list, many to one
-      forms a scalar, many to many is a list. If a scalar is desired
-      where normally a list would be present, such as a bi-directional
-      one-to-one relationship, set uselist to False.
-
-    :param viewonly=False:
-      when set to True, the relationship is used only for loading objects
-      within the relationship, and has no effect on the unit-of-work
-      flush process.  Relationships with viewonly can specify any kind of
-      join conditions to provide additional views of related objects
-      onto a parent object. Note that the functionality of a viewonly
-      relationship has its limits - complicated join conditions may
-      not compile into eager or lazy loaders properly. If this is the
-      case, use an alternative method.
-
-    """
-    return RelationshipProperty(argument, secondary=secondary, **kwargs)

 def relation(*arg, **kw):
     """A synonym for :func:`relationship`."""
-
+
     return relationship(*arg, **kw)
-
-def dynamic_loader(argument, secondary=None, primaryjoin=None,
-                   secondaryjoin=None, foreign_keys=None, backref=None,
-                   post_update=False, cascade=False, remote_side=None,
-                   enable_typechecks=True, passive_deletes=False,
-                   order_by=None, comparator_factory=None, query_class=None):
+
+
+def dynamic_loader(argument, **kw):
     """Construct a dynamically-loading mapper property.

-    This property is similar to :func:`relationship`, except read
-    operations return an active :class:`Query` object which reads from
-    the database when accessed. Items may be appended to the
-    attribute via ``append()``, or removed via ``remove()``; changes
-    will be persisted to the database during a :meth:`Session.flush`.
-    However, no other Python list or collection mutation operations
-    are available.
+    This is essentially the same as
+    using the ``lazy='dynamic'`` argument with :func:`relationship`::

-    A subset of arguments available to :func:`relationship` are available
-    here.
+        dynamic_loader(SomeClass)

-    :param argument:
-      a class or :class:`Mapper` instance, representing the target of
-      the relationship.
+        # is the same as

-    :param secondary:
-      for a many-to-many relationship, specifies the intermediary
-      table. The *secondary* keyword argument should generally only
-      be used for a table that is not otherwise expressed in any class
-      mapping. In particular, using the Association Object Pattern is
-      generally mutually exclusive with the use of the *secondary*
-      keyword argument.
+        relationship(SomeClass, lazy="dynamic")

-    :param query_class:
-      Optional, a custom Query subclass to be used as the basis for
-      dynamic collection.
+    See the section :ref:`dynamic_relationship` for more details
+    on dynamic loading.

     """
-    from sqlalchemy.orm.dynamic import DynaLoader
+    kw['lazy'] = 'dynamic'
+    return relationship(argument, **kw)

-    return RelationshipProperty(
-        argument, secondary=secondary, primaryjoin=primaryjoin,
-        secondaryjoin=secondaryjoin, foreign_keys=foreign_keys, backref=backref,
-        post_update=post_update, cascade=cascade, remote_side=remote_side,
-        enable_typechecks=enable_typechecks, passive_deletes=passive_deletes,
-        order_by=order_by, comparator_factory=comparator_factory,
-        strategy_class=DynaLoader, query_class=query_class)

-def column_property(*args, **kwargs):
-    """Provide a column-level property for use with a Mapper.
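Before the parameter notes that follow, a minimal sketch of such a
SQL-expression property; the names are hypothetical::

    from sqlalchemy import Table, Column, Integer, String, MetaData
    from sqlalchemy.orm import mapper, column_property

    metadata = MetaData()
    users = Table('users', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('firstname', String(50)),
                  Column('lastname', String(50)))

    class User(object):
        pass

    # a SQL expression mapped as an effectively read-only attribute;
    # it is loaded by the mapper but never persisted
    mapper(User, users, properties={
        'fullname': column_property(
            (users.c.firstname + " " + users.c.lastname).label('fullname'))
    })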
-
-    Column-based properties can normally be applied to the mapper's
-    ``properties`` dictionary using the ``schema.Column`` element directly.
-    Use this function when the given column is not directly present within the
-    mapper's selectable; examples include SQL expressions, functions, and
-    scalar SELECT queries.
-
-    Columns that aren't present in the mapper's selectable won't be persisted
-    by the mapper and are effectively "read-only" attributes.
-
-    \*cols
-      list of Column objects to be mapped.
-
-    comparator_factory
-      a class which extends ``sqlalchemy.orm.properties.ColumnProperty.Comparator``
-      which provides custom SQL clause generation for comparison operations.
-
-    group
-      a group name for this property when marked as deferred.
-
-    deferred
-      when True, the column property is "deferred", meaning that
-      it does not load immediately, and is instead loaded when the
-      attribute is first accessed on an instance.  See also
-      :func:`~sqlalchemy.orm.deferred`.
-
-    extension
-      an :class:`~sqlalchemy.orm.interfaces.AttributeExtension` instance,
-      or list of extensions, which will be prepended to the list of
-      attribute listeners for the resulting descriptor placed on the class.
-      These listeners will receive append and set events before the
-      operation proceeds, and may be used to halt (via exception throw)
-      or change the value used in the operation.
-
-    """
-
-    return ColumnProperty(*args, **kwargs)
-
-def composite(class_, *cols, **kwargs):
-    """Return a composite column-based property for use with a Mapper.
-
-    This is very much like a column-based property except the given class is
-    used to represent "composite" values composed of one or more columns.
-
-    The class must implement a constructor with positional arguments matching
-    the order of columns supplied here, as well as a __composite_values__()
-    method which returns values in the same order.
-
-    A simple example is representing two separate columns in a table as a
-    single, first-class "Point" object::
-
-      class Point(object):
-          def __init__(self, x, y):
-              self.x = x
-              self.y = y
-          def __composite_values__(self):
-              return self.x, self.y
-          def __eq__(self, other):
-              return other is not None and self.x == other.x and self.y == other.y
-
-      # and then in the mapping:
-      ... composite(Point, mytable.c.x, mytable.c.y) ...
-
-    The composite object may have its attributes populated based on the names
-    of the mapped columns. To override the way internal state is set,
-    additionally implement ``__set_composite_values__``::
-
-        class Point(object):
-            def __init__(self, x, y):
-                self.some_x = x
-                self.some_y = y
-            def __composite_values__(self):
-                return self.some_x, self.some_y
-            def __set_composite_values__(self, x, y):
-                self.some_x = x
-                self.some_y = y
-            def __eq__(self, other):
-                return other is not None and self.some_x == other.some_x and self.some_y == other.some_y
-
-    Arguments are:
-
-    class\_
-      The "composite type" class.
-
-    \*cols
-      List of Column objects to be mapped.
-
-    group
-      A group name for this property when marked as deferred.
-
-    deferred
-      When True, the column property is "deferred", meaning that it does not
-      load immediately, and is instead loaded when the attribute is first
-      accessed on an instance. See also :func:`~sqlalchemy.orm.deferred`.
-
-    comparator_factory
-      a class which extends ``sqlalchemy.orm.properties.CompositeProperty.Comparator``
-      which provides custom SQL clause generation for comparison operations.
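Before the remaining parameter notes, a hedged end-to-end sketch of
``composite``, reusing the first ``Point`` class shown above; the table
and attribute names are hypothetical::

    from sqlalchemy import Table, Column, Integer, MetaData
    from sqlalchemy.orm import mapper, composite

    metadata = MetaData()
    vertices = Table('vertices', metadata,
                     Column('id', Integer, primary_key=True),
                     Column('x1', Integer),
                     Column('y1', Integer))

    class Vertex(object):
        pass

    # 'start' loads and compares as a single Point value composed
    # of the x1/y1 columns
    mapper(Vertex, vertices, properties={
        'start': composite(Point, vertices.c.x1, vertices.c.y1)
    })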
- - extension - an :class:`~sqlalchemy.orm.interfaces.AttributeExtension` instance, - or list of extensions, which will be prepended to the list of - attribute listeners for the resulting descriptor placed on the class. - These listeners will receive append and set events before the - operation proceeds, and may be used to halt (via exception throw) - or change the value used in the operation. - - """ - return CompositeProperty(class_, *cols, **kwargs) +column_property = public_factory(ColumnProperty, ".orm.column_property") +composite = public_factory(CompositeProperty, ".orm.composite") def backref(name, **kwargs): - """Create a back reference with explicit arguments, which are the same - arguments one can send to ``relationship()``. + """Create a back reference with explicit keyword arguments, which are the + same arguments one can send to :func:`relationship`. - Used with the `backref` keyword argument to ``relationship()`` in - place of a string argument. + Used with the ``backref`` keyword argument to :func:`relationship` in + place of a string argument, e.g.:: + + 'items':relationship( + SomeItem, backref=backref('parent', lazy='subquery')) + + .. seealso:: + + :ref:`relationships_backref` """ + return (name, kwargs) -def deferred(*columns, **kwargs): - """Return a ``DeferredColumnProperty``, which indicates this - object attributes should only be loaded from its corresponding - table column when first accessed. - Used with the `properties` dictionary sent to ``mapper()``. +def deferred(*columns, **kw): + r"""Indicate a column-based mapped attribute that by default will + not load unless accessed. + + :param \*columns: columns to be mapped. This is typically a single + :class:`.Column` object, however a collection is supported in order + to support multiple columns mapped under the same attribute. + + :param \**kw: additional keyword arguments passed to + :class:`.ColumnProperty`. + + .. seealso:: + + :ref:`deferred` """ - return ColumnProperty(deferred=True, *columns, **kwargs) - -def mapper(class_, local_table=None, *args, **params): - """Return a new :class:`~sqlalchemy.orm.Mapper` object. - - :param class\_: The class to be mapped. - - :param local_table: The table to which the class is mapped, or None if this mapper - inherits from another mapper using concrete table inheritance. - - :param always_refresh: If True, all query operations for this mapped class will overwrite all - data within object instances that already exist within the session, - erasing any in-memory changes with whatever information was loaded - from the database. Usage of this flag is highly discouraged; as an - alternative, see the method `populate_existing()` on - :class:`~sqlalchemy.orm.query.Query`. - - :param allow_null_pks: This flag is deprecated - this is stated as allow_partial_pks - which defaults to True. - - :param allow_partial_pks: Defaults to True. Indicates that a composite primary key with - some NULL values should be considered as possibly existing - within the database. This affects whether a mapper will assign - an incoming row to an existing identity, as well as if - session.merge() will check the database first for a particular - primary key value. A "partial primary key" can occur if one - has mapped to an OUTER JOIN, for example. - - :param batch: Indicates that save operations of multiple entities can be batched - together for efficiency. 
setting to False indicates that an instance - will be fully saved before saving the next instance, which includes - inserting/updating all table rows corresponding to the entity as well - as calling all ``MapperExtension`` methods corresponding to the save - operation. - - :param column_prefix: A string which will be prepended to the `key` name of all Columns when - creating column-based properties from the given Table. Does not - affect explicitly specified column-based properties - - :param concrete: If True, indicates this mapper should use concrete table inheritance - with its parent mapper. - - :param exclude_properties: A list of properties not to map. Columns present in the mapped table - and present in this list will not be automatically converted into - properties. Note that neither this option nor include_properties will - allow an end-run around Python inheritance. If mapped class ``B`` - inherits from mapped class ``A``, no combination of includes or - excludes will allow ``B`` to have fewer properties than its - superclass, ``A``. + return ColumnProperty(deferred=True, *columns, **kw) - :param extension: A :class:`~sqlalchemy.orm.interfaces.MapperExtension` instance or list of - :class:`~sqlalchemy.orm.interfaces.MapperExtension` instances which will be applied to all - operations by this :class:`~sqlalchemy.orm.mapper.Mapper`. +mapper = public_factory(Mapper, ".orm.mapper") - :param include_properties: An inclusive list of properties to map. Columns present in the mapped - table but not present in this list will not be automatically converted - into properties. +synonym = public_factory(SynonymProperty, ".orm.synonym") - :param inherits: Another :class:`~sqlalchemy.orm.Mapper` for which - this :class:`~sqlalchemy.orm.Mapper` will have an inheritance - relationship with. +comparable_property = public_factory(ComparableProperty, + ".orm.comparable_property") - :param inherit_condition: For joined table inheritance, a SQL expression (constructed - ``ClauseElement``) which will define how the two tables are joined; - defaults to a natural join between the two tables. - - :param inherit_foreign_keys: When inherit_condition is used and the condition contains no - ForeignKey columns, specify the "foreign" columns of the join - condition in this list. else leave as None. - - :param non_primary: Construct a ``Mapper`` that will define only the selection of - instances, not their persistence. Any number of non_primary mappers - may be created for a particular class. - - :param order_by: A single ``Column`` or list of ``Columns`` for which - selection operations should use as the default ordering for entities. - Defaults to the OID/ROWID of the table if any, or the first primary - key column of the table. - - :param passive_updates: Indicates UPDATE behavior of foreign keys when a primary key changes - on a joined-table inheritance or other joined table mapping. - - When True, it is assumed that ON UPDATE CASCADE is configured on - the foreign key in the database, and that the database will - handle propagation of an UPDATE from a source column to - dependent rows. Note that with databases which enforce - referential integrity (i.e. PostgreSQL, MySQL with InnoDB tables), - ON UPDATE CASCADE is required for this operation. The - relationship() will update the value of the attribute on related - items which are locally present in the session during a flush. 
-
-      When False, it is assumed that the database does not enforce
-      referential integrity and will not be issuing its own CASCADE
-      operation for an update.  The relationship() will issue the
-      appropriate UPDATE statements to the database in response to the
-      change of a referenced key, and items locally present in the
-      session during a flush will also be refreshed.
-
-      This flag should probably be set to False if primary key changes
-      are expected and the database in use doesn't support CASCADE
-      (i.e. SQLite, MySQL MyISAM tables).
-
-      Also see the passive_updates flag on :func:`relationship()`.
-
-      A future SQLAlchemy release will provide a "detect" feature for
-      this flag.
-
-    :param polymorphic_on: Used with mappers in an inheritance relationship, a ``Column`` which
-      will identify the class/mapper combination to be used with a
-      particular row. Requires the ``polymorphic_identity`` value to be set
-      for all mappers in the inheritance hierarchy. The column specified by
-      ``polymorphic_on`` is usually a column that resides directly within
-      the base mapper's mapped table; alternatively, it may be a column that
-      is only present within the <selectable> portion of the
-      ``with_polymorphic`` argument.
-
-    :param polymorphic_identity: A value which will be stored in the Column denoted by polymorphic_on,
-      corresponding to the *class identity* of this mapper.
-
-    :param properties: A dictionary mapping the string names of object attributes to
-      ``MapperProperty`` instances, which define the persistence behavior of
-      that attribute. Note that the columns in the mapped table are
-      automatically converted into ``ColumnProperty`` instances based on the
-      `key` property of each ``Column`` (although they can be overridden
-      using this dictionary).
-
-    :param primary_key: A list of ``Column`` objects which define the *primary key* to be used
-      against this mapper's selectable unit. This is normally simply the
-      primary key of the `local_table`, but can be overridden here.
-
-    :param version_id_col: A ``Column`` which must have an integer type that will be used to keep
-      a running *version id* of mapped entities in the database. This is
-      used during save operations to ensure that no other thread or process
-      has updated the instance during the lifetime of the entity, else a
-      ``ConcurrentModificationError`` exception is thrown.
-
-    :param version_id_generator: A callable which defines the algorithm used to generate new version
-      ids. Defaults to an integer generator. Can be replaced with one that
-      generates timestamps, uuids, etc. e.g.::
-
-          import uuid
-
-          mapper(Cls, table,
-              version_id_col=table.c.version_uuid,
-              version_id_generator=lambda version:uuid.uuid4().hex
-          )
-
-      The callable receives the current version identifier as its
-      single argument.
-
-    :param with_polymorphic: A tuple in the form ``(<classes>, <selectable>)`` indicating the
-      default style of "polymorphic" loading, that is, which tables are
-      queried at once. <classes> is any single or list of mappers and/or
-      classes indicating the inherited classes that should be loaded at
-      once. The special value ``'*'`` may be used to indicate all descending
-      classes should be loaded immediately. The second tuple argument
-      <selectable> indicates a selectable that will be used to query for
-      multiple classes. Normally, it is left as None, in which case this
-      mapper will form an outer join from the base mapper's table to that of
-      all desired sub-mappers.  When specified, it provides the selectable
-      to be used for polymorphic loading.
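Before the note below on "concrete" tables, a sketch tying together the
``inherits``, ``polymorphic_on``, ``polymorphic_identity`` and
``with_polymorphic`` parameters described above; all names are
hypothetical::

    from sqlalchemy import (Table, Column, Integer, String,
                            ForeignKey, MetaData)
    from sqlalchemy.orm import mapper

    metadata = MetaData()
    employees = Table('employees', metadata,
                      Column('id', Integer, primary_key=True),
                      Column('type', String(20)))
    engineers = Table('engineers', metadata,
                      Column('id', Integer, ForeignKey('employees.id'),
                             primary_key=True),
                      Column('language', String(50)))

    class Employee(object):
        pass

    class Engineer(Employee):
        pass

    mapper(Employee, employees,
           polymorphic_on=employees.c.type,
           polymorphic_identity='employee',
           # '*' requests that all subclass tables be joined when
           # querying against Employee
           with_polymorphic=('*', None))
    mapper(Engineer, engineers,
           inherits=Employee,
           polymorphic_identity='engineer')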
When with_polymorphic includes
-      mappers which load from a "concrete" inheriting table, the
-      <selectable> argument is required, since it usually requires more
-      complex UNION queries.
-
-
-    """
-    return Mapper(class_, local_table, *args, **params)
-
-def synonym(name, map_column=False, descriptor=None, comparator_factory=None):
-    """Set up `name` as a synonym to another mapped property.
-
-    Used with the ``properties`` dictionary sent to :func:`~sqlalchemy.orm.mapper`.
-
-    Any existing attributes on the class which map the key name sent
-    to the ``properties`` dictionary will be used by the synonym to provide
-    instance-attribute behavior (that is, any Python property object, provided
-    by the ``property`` builtin or providing a ``__get__()``, ``__set__()``
-    and ``__del__()`` method).  If no name exists for the key, the
-    ``synonym()`` creates a default getter/setter object automatically and
-    applies it to the class.
-
-    `name` refers to the name of the existing mapped property, which can be
-    any other ``MapperProperty`` including column-based properties and
-    relationships.
-
-    If `map_column` is ``True``, an additional ``ColumnProperty`` is created
-    on the mapper automatically, using the synonym's name as the keyname of
-    the property, and the keyname of this ``synonym()`` as the name of the
-    column to map.  For example, if a table has a column named ``status``::
-
-        class MyClass(object):
-            def _get_status(self):
-                return self._status
-            def _set_status(self, value):
-                self._status = value
-            status = property(_get_status, _set_status)
-
-        mapper(MyClass, sometable, properties={
-            "status":synonym("_status", map_column=True)
-        })
-
-    The column named ``status`` will be mapped to the attribute named
-    ``_status``, and the ``status`` attribute on ``MyClass`` will be used to
-    proxy access to the column-based attribute.
-
-    """
-    return SynonymProperty(name, map_column=map_column, descriptor=descriptor, comparator_factory=comparator_factory)
-
-def comparable_property(comparator_factory, descriptor=None):
-    """Provide query semantics for an unmanaged attribute.
-
-    Allows a regular Python @property (descriptor) to be used in Queries and
-    SQL constructs like a managed attribute.  comparable_property wraps a
-    descriptor with a proxy that directs operator overrides such as ==
-    (__eq__) to the supplied comparator but proxies everything else through to
-    the original descriptor::
-
-        class MyClass(object):
-            @property
-            def myprop(self):
-                return 'foo'
-
-        class MyComparator(sqlalchemy.orm.interfaces.PropComparator):
-            def __eq__(self, other):
-                ....
-
-        mapper(MyClass, mytable, properties={
-            'myprop': comparable_property(MyComparator)})
-
-    Used with the ``properties`` dictionary sent to :func:`~sqlalchemy.orm.mapper`.
-
-    comparator_factory
-      A PropComparator subclass or factory that defines operator behavior
-      for this property.
-
-    descriptor
-      Optional when used in a ``properties={}`` declaration.  The Python
-      descriptor or property to layer comparison behavior on top of.
-
-      The like-named descriptor will be automatically retrieved from the
-      mapped class if left blank in a ``properties`` declaration.
-
-    """
-    return ComparableProperty(comparator_factory, descriptor)
-
+@_sa_util.deprecated("0.7", message=":func:`.compile_mappers` "
+                     "is renamed to :func:`.configure_mappers`")
 def compile_mappers():
-    """Compile all mappers that have been defined.
-
-    This is equivalent to calling ``compile()`` on any individual mapper.
+ """Initialize the inter-mapper relationships of all mappers that have + been defined. """ - for m in list(_mapper_registry): - m.compile() + configure_mappers() + def clear_mappers(): - """Remove all mappers that have been created thus far. + """Remove all mappers from all classes. - The mapped classes will return to their initial "unmapped" state and can - be re-mapped with new mappers. + This function removes all instrumentation from classes and disposes + of their associated mappers. Once called, the classes are unmapped + and can be later re-mapped with new mappers. + + :func:`.clear_mappers` is *not* for normal use, as there is literally no + valid usage for it outside of very specific testing scenarios. Normally, + mappers are permanent structural components of user-defined classes, and + are never discarded independently of their class. If a mapped class + itself is garbage collected, its mapper is automatically disposed of as + well. As such, :func:`.clear_mappers` is only for usage in test suites + that re-use the same classes with different mappings, which is itself an + extremely rare use case - the only such use case is in fact SQLAlchemy's + own test suite, and possibly the test suites of other ORM extension + libraries which intend to test various combinations of mapper construction + upon a fixed set of classes. """ - mapperlib._COMPILE_MUTEX.acquire() + mapperlib._CONFIGURE_MUTEX.acquire() try: while _mapper_registry: try: @@ -916,261 +225,52 @@ def clear_mappers(): except KeyError: pass finally: - mapperlib._COMPILE_MUTEX.release() + mapperlib._CONFIGURE_MUTEX.release() -def extension(ext): - """Return a ``MapperOption`` that will insert the given - ``MapperExtension`` to the beginning of the list of extensions - that will be called in the context of the ``Query``. +from . import strategy_options - Used with :meth:`~sqlalchemy.orm.query.Query.options`. +joinedload = strategy_options.joinedload._unbound_fn +joinedload_all = strategy_options.joinedload._unbound_all_fn +contains_eager = strategy_options.contains_eager._unbound_fn +defer = strategy_options.defer._unbound_fn +undefer = strategy_options.undefer._unbound_fn +undefer_group = strategy_options.undefer_group._unbound_fn +load_only = strategy_options.load_only._unbound_fn +lazyload = strategy_options.lazyload._unbound_fn +lazyload_all = strategy_options.lazyload_all._unbound_all_fn +subqueryload = strategy_options.subqueryload._unbound_fn +subqueryload_all = strategy_options.subqueryload_all._unbound_all_fn +immediateload = strategy_options.immediateload._unbound_fn +noload = strategy_options.noload._unbound_fn +raiseload = strategy_options.raiseload._unbound_fn +defaultload = strategy_options.defaultload._unbound_fn - """ - return ExtensionOption(ext) +from .strategy_options import Load -@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') -def joinedload(*keys, **kw): - """Return a ``MapperOption`` that will convert the property of the given - name into an joined eager load. - - .. note:: This function is known as :func:`eagerload` in all versions - of SQLAlchemy prior to version 0.6beta3, including the 0.5 and 0.4 series. - :func:`eagerload` will remain available for - the foreseeable future in order to enable cross-compatibility. - - Used with :meth:`~sqlalchemy.orm.query.Query.options`. 
-
-    examples::
-
-        # joined-load the "orders" collection on "User"
-        query(User).options(joinedload(User.orders))
-
-        # joined-load the "keywords" collection on each "Item",
-        # but not the "items" collection on "Order" - those
-        # remain lazily loaded.
-        query(Order).options(joinedload(Order.items, Item.keywords))
-
-        # to joined-load across both, use joinedload_all()
-        query(Order).options(joinedload_all(Order.items, Item.keywords))
-
-    :func:`joinedload` also accepts a keyword argument `innerjoin=True` which
-    indicates using an inner join instead of an outer::
-
-        query(Order).options(joinedload(Order.user, innerjoin=True))
-
-    Note that the join created by :func:`joinedload` is aliased such that
-    no other aspects of the query will affect what it loads. To use joined eager
-    loading with a join that is constructed manually using :meth:`~sqlalchemy.orm.query.Query.join`
-    or :func:`~sqlalchemy.orm.join`, see :func:`contains_eager`.
-
-    See also:  :func:`subqueryload`, :func:`lazyload`
-
-    """
-    innerjoin = kw.pop('innerjoin', None)
-    if innerjoin is not None:
-        return (
-            strategies.EagerLazyOption(keys, lazy='joined'),
-            strategies.EagerJoinOption(keys, innerjoin)
-        )
-    else:
-        return strategies.EagerLazyOption(keys, lazy='joined')
-
-@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated')
-def joinedload_all(*keys, **kw):
-    """Return a ``MapperOption`` that will convert all properties along the
-    given dot-separated path into a joined eager load.
-
-    .. note:: This function is known as :func:`eagerload_all` in all versions
-        of SQLAlchemy prior to version 0.6beta3, including the 0.5 and 0.4 series.
-        :func:`eagerload_all` will remain available for
-        the foreseeable future in order to enable cross-compatibility.
-
-    Used with :meth:`~sqlalchemy.orm.query.Query.options`.
-
-    For example::
-
-        query.options(joinedload_all('orders.items.keywords'))...
-
-    will set all of 'orders', 'orders.items', and 'orders.items.keywords' to
-    load in one joined eager load.
-
-    Individual descriptors are accepted as arguments as well::
-
-        query.options(joinedload_all(User.orders, Order.items, Item.keywords))
-
-    The keyword arguments accept a flag `innerjoin=True|False` which will
-    override the value of the `innerjoin` flag specified on the relationship().
-
-    See also:  :func:`subqueryload_all`, :func:`lazyload`
-
-    """
-    innerjoin = kw.pop('innerjoin', None)
-    if innerjoin is not None:
-        return (
-            strategies.EagerLazyOption(keys, lazy='joined', chained=True),
-            strategies.EagerJoinOption(keys, innerjoin, chained=True)
-        )
-    else:
-        return strategies.EagerLazyOption(keys, lazy='joined', chained=True)

 def eagerload(*args, **kwargs):
     """A synonym for :func:`joinedload()`."""
     return joinedload(*args, **kwargs)
-
+
+
 def eagerload_all(*args, **kwargs):
     """A synonym for :func:`joinedload_all()`"""
     return joinedload_all(*args, **kwargs)
-
-def subqueryload(*keys):
-    """Return a ``MapperOption`` that will convert the property
-    of the given name into a subquery eager load.
-    .. note:: This function is new as of SQLAlchemy version 0.6beta3.
-    Used with :meth:`~sqlalchemy.orm.query.Query.options`.
+contains_alias = public_factory(AliasOption, ".orm.contains_alias")
-    examples::
-
-        # subquery-load the "orders" collection on "User"
-        query(User).options(subqueryload(User.orders))
-
-        # subquery-load the "keywords" collection on each "Item",
-        # but not the "items" collection on "Order" - those
-        # remain lazily loaded.
- query(Order).options(subqueryload(Order.items, Item.keywords)) - # to subquery-load across both, use subqueryload_all() - query(Order).options(subqueryload_all(Order.items, Item.keywords)) +def __go(lcls): + global __all__ + from .. import util as sa_util + from . import dynamic + from . import events + import inspect as _inspect - See also: :func:`joinedload`, :func:`lazyload` - - """ - return strategies.EagerLazyOption(keys, lazy="subquery") + __all__ = sorted(name for name, obj in lcls.items() + if not (name.startswith('_') or _inspect.ismodule(obj))) -def subqueryload_all(*keys): - """Return a ``MapperOption`` that will convert all properties along the - given dot-separated path into a subquery eager load. + _sa_util.dependencies.resolve_all("sqlalchemy.orm") - .. note:: This function is new as of SQLAlchemy version 0.6beta3. - - Used with :meth:`~sqlalchemy.orm.query.Query.options`. - - For example:: - - query.options(subqueryload_all('orders.items.keywords'))... - - will set all of 'orders', 'orders.items', and 'orders.items.keywords' to - load in one subquery eager load. - - Individual descriptors are accepted as arguments as well:: - - query.options(subqueryload_all(User.orders, Order.items, Item.keywords)) - - See also: :func:`joinedload_all`, :func:`lazyload` - - """ - return strategies.EagerLazyOption(keys, lazy="subquery", chained=True) - -@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') -def lazyload(*keys): - """Return a ``MapperOption`` that will convert the property of the given - name into a lazy load. - - Used with :meth:`~sqlalchemy.orm.query.Query.options`. - - See also: :func:`eagerload`, :func:`subqueryload` - - """ - return strategies.EagerLazyOption(keys, lazy=True) - -def noload(*keys): - """Return a ``MapperOption`` that will convert the property of the - given name into a non-load. - - Used with :meth:`~sqlalchemy.orm.query.Query.options`. - - See also: :func:`lazyload`, :func:`eagerload`, :func:`subqueryload` - - """ - return strategies.EagerLazyOption(keys, lazy=None) - -def contains_alias(alias): - """Return a ``MapperOption`` that will indicate to the query that - the main table has been aliased. - - `alias` is the string name or ``Alias`` object representing the - alias. - - """ - return AliasOption(alias) - -@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') -def contains_eager(*keys, **kwargs): - """Return a ``MapperOption`` that will indicate to the query that - the given attribute should be eagerly loaded from columns currently - in the query. - - Used with :meth:`~sqlalchemy.orm.query.Query.options`. - - The option is used in conjunction with an explicit join that loads - the desired rows, i.e.:: - - sess.query(Order).\\ - join(Order.user).\\ - options(contains_eager(Order.user)) - - The above query would join from the ``Order`` entity to its related - ``User`` entity, and the returned ``Order`` objects would have the - ``Order.user`` attribute pre-populated. - - :func:`contains_eager` also accepts an `alias` argument, which - is the string name of an alias, an :func:`~sqlalchemy.sql.expression.alias` - construct, or an :func:`~sqlalchemy.orm.aliased` construct. Use this - when the eagerly-loaded rows are to come from an aliased table:: - - user_alias = aliased(User) - sess.query(Order).\\ - join((user_alias, Order.user)).\\ - options(contains_eager(Order.user, alias=user_alias)) - - See also :func:`eagerload` for the "automatic" version of this - functionality. 
- - """ - alias = kwargs.pop('alias', None) - if kwargs: - raise exceptions.ArgumentError("Invalid kwargs for contains_eager: %r" % kwargs.keys()) - - return ( - strategies.EagerLazyOption(keys, lazy='joined', propagate_to_loaders=False), - strategies.LoadEagerFromAliasOption(keys, alias=alias) - ) - -@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') -def defer(*keys): - """Return a ``MapperOption`` that will convert the column property of the - given name into a deferred load. - - Used with :meth:`~sqlalchemy.orm.query.Query.options`. - - """ - return strategies.DeferredOption(keys, defer=True) - -@sa_util.accepts_a_list_as_starargs(list_deprecation='deprecated') -def undefer(*keys): - """Return a ``MapperOption`` that will convert the column property of the - given name into a non-deferred (regular column) load. - - Used with :meth:`~sqlalchemy.orm.query.Query.options`. - - """ - return strategies.DeferredOption(keys, defer=False) - -def undefer_group(name): - """Return a ``MapperOption`` that will convert the given group of deferred - column properties into a non-deferred (regular column) load. - - Used with :meth:`~sqlalchemy.orm.query.Query.options`. - - """ - return strategies.UndeferGroupOption(name) +__go(locals()) diff --git a/sqlalchemy/orm/attributes.py b/sqlalchemy/orm/attributes.py index 887d9a9..fc81db7 100644 --- a/sqlalchemy/orm/attributes.py +++ b/sqlalchemy/orm/attributes.py @@ -1,120 +1,175 @@ -# attributes.py - manages object attributes -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# orm/attributes.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Defines SQLAlchemy's system of class instrumentation.. + +"""Defines instrumentation for class attributes and their interaction +with instances. This module is usually not directly visible to user applications, but defines a large part of the ORM's interactivity. -SQLA's instrumentation system is completely customizable, in which -case an understanding of the general mechanics of this module is helpful. -An example of full customization is in /examples/custom_attributes. """ import operator -from operator import attrgetter, itemgetter -import types -import weakref +from .. import util, event, inspection +from . 
import interfaces, collections, exc as orm_exc -from sqlalchemy import util -from sqlalchemy.orm import interfaces, collections, exc -import sqlalchemy.exceptions as sa_exc +from .base import instance_state, instance_dict, manager_of_class -# lazy imports -_entity_info = None -identity_equal = None -state = None +from .base import PASSIVE_NO_RESULT, ATTR_WAS_SET, ATTR_EMPTY, NO_VALUE,\ + NEVER_SET, NO_CHANGE, CALLABLES_OK, SQL_OK, RELATED_OBJECT_OK,\ + INIT_OK, NON_PERSISTENT_OK, LOAD_AGAINST_COMMITTED, PASSIVE_OFF,\ + PASSIVE_RETURN_NEVER_SET, PASSIVE_NO_INITIALIZE, PASSIVE_NO_FETCH,\ + PASSIVE_NO_FETCH_RELATED, PASSIVE_ONLY_PERSISTENT, NO_AUTOFLUSH +from .base import state_str, instance_str -PASSIVE_NO_RESULT = util.symbol('PASSIVE_NO_RESULT') -ATTR_WAS_SET = util.symbol('ATTR_WAS_SET') -NO_VALUE = util.symbol('NO_VALUE') -NEVER_SET = util.symbol('NEVER_SET') -# "passive" get settings -# TODO: the True/False values need to be factored out -# of the rest of ORM code -# don't fire off any callables, and don't initialize the attribute to -# an empty value -PASSIVE_NO_INITIALIZE = True #util.symbol('PASSIVE_NO_INITIALIZE') +@inspection._self_inspects +class QueryableAttribute(interfaces._MappedAttribute, + interfaces.InspectionAttr, + interfaces.PropComparator): + """Base class for :term:`descriptor` objects that intercept + attribute events on behalf of a :class:`.MapperProperty` + object. The actual :class:`.MapperProperty` is accessible + via the :attr:`.QueryableAttribute.property` + attribute. -# don't fire off any callables, but if no callables present -# then initialize to an empty value/collection -# this is used by backrefs. -PASSIVE_NO_FETCH = util.symbol('PASSIVE_NO_FETCH') -# fire callables/initialize as needed -PASSIVE_OFF = False #util.symbol('PASSIVE_OFF') + .. seealso:: -INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__' -"""Attribute, elects custom instrumentation when present on a mapped class. + :class:`.InstrumentedAttribute` -Allows a class to specify a slightly or wildly different technique for -tracking changes made to mapped attributes and collections. + :class:`.MapperProperty` -Only one instrumentation implementation is allowed in a given object -inheritance hierarchy. + :attr:`.Mapper.all_orm_descriptors` -The value of this attribute must be a callable and will be passed a class -object. The callable must return one of: + :attr:`.Mapper.attrs` + """ - - An instance of an interfaces.InstrumentationManager or subclass - - An object implementing all or some of InstrumentationManager (TODO) - - A dictionary of callables, implementing all or some of the above (TODO) - - An instance of a ClassManager or subclass + is_attribute = True -interfaces.InstrumentationManager is public API and will remain stable -between releases. ClassManager is not public and no guarantees are made -about stability. Caveat emptor. - -This attribute is consulted by the default SQLAlchemy instrumentation -resolution code. If custom finders are installed in the global -instrumentation_finders list, they may or may not choose to honor this -attribute. - -""" - -instrumentation_finders = [] -"""An extensible sequence of instrumentation implementation finding callables. - -Finders callables will be passed a class object. If None is returned, the -next finder in the sequence is consulted. Otherwise the return must be an -instrumentation factory that follows the same guidelines as -INSTRUMENTATION_MANAGER. 
- -By default, the only finder is find_native_user_instrumentation_hook, which -searches for INSTRUMENTATION_MANAGER. If all finders return None, standard -ClassManager instrumentation is used. - -""" - -class QueryableAttribute(interfaces.PropComparator): - - def __init__(self, key, impl=None, comparator=None, parententity=None): - """Construct an InstrumentedAttribute. - - comparator - a sql.Comparator to which class-level compare/math events will be sent - """ + def __init__(self, class_, key, impl=None, + comparator=None, parententity=None, + of_type=None): + self.class_ = class_ self.key = key self.impl = impl self.comparator = comparator - self.parententity = parententity + self._parententity = parententity + self._of_type = of_type - def get_history(self, instance, **kwargs): - return self.impl.get_history(instance_state(instance), instance_dict(instance), **kwargs) + manager = manager_of_class(class_) + # manager is None in the case of AliasedClass + if manager: + # propagate existing event listeners from + # immediate superclass + for base in manager._bases: + if key in base: + self.dispatch._update(base[key].dispatch) + + @util.memoized_property + def _supports_population(self): + return self.impl.supports_population + + def get_history(self, instance, passive=PASSIVE_OFF): + return self.impl.get_history(instance_state(instance), + instance_dict(instance), passive) def __selectable__(self): # TODO: conditionally attach this method based on clause_element ? return self + @util.memoized_property + def info(self): + """Return the 'info' dictionary for the underlying SQL element. + + The behavior here is as follows: + + * If the attribute is a column-mapped property, i.e. + :class:`.ColumnProperty`, which is mapped directly + to a schema-level :class:`.Column` object, this attribute + will return the :attr:`.SchemaItem.info` dictionary associated + with the core-level :class:`.Column` object. + + * If the attribute is a :class:`.ColumnProperty` but is mapped to + any other kind of SQL expression other than a :class:`.Column`, + the attribute will refer to the :attr:`.MapperProperty.info` + dictionary associated directly with the :class:`.ColumnProperty`, + assuming the SQL expression itself does not have its own ``.info`` + attribute (which should be the case, unless a user-defined SQL + construct has defined one). + + * If the attribute refers to any other kind of + :class:`.MapperProperty`, including :class:`.RelationshipProperty`, + the attribute will refer to the :attr:`.MapperProperty.info` + dictionary associated with that :class:`.MapperProperty`. + + * To access the :attr:`.MapperProperty.info` dictionary of the + :class:`.MapperProperty` unconditionally, including for a + :class:`.ColumnProperty` that's associated directly with a + :class:`.schema.Column`, the attribute can be referred to using + :attr:`.QueryableAttribute.property` attribute, as + ``MyClass.someattribute.property.info``. + + .. versionadded:: 0.8.0 + + .. seealso:: + + :attr:`.SchemaItem.info` + + :attr:`.MapperProperty.info` + + """ + return self.comparator.info + + @util.memoized_property + def parent(self): + """Return an inspection instance representing the parent. + + This will be either an instance of :class:`.Mapper` + or :class:`.AliasedInsp`, depending upon the nature + of the parent entity which this attribute is associated + with. 
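A brief sketch of the ``info`` and ``parent`` accessors described above,
using hypothetical declarative names::

    from sqlalchemy import Column, Integer, String, inspect
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'users'
        id = Column(Integer, primary_key=True)
        name = Column(String(50), info={'label': 'Full name'})

    # for a column-mapped attribute, .info proxies through to the
    # Column's info dictionary
    assert User.name.info == {'label': 'Full name'}

    # .parent is the inspection object for the owning entity,
    # here the Mapper for User
    assert User.name.parent is inspect(User)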
+ + """ + return inspection.inspect(self._parententity) + + @property + def expression(self): + return self.comparator.__clause_element__() + def __clause_element__(self): return self.comparator.__clause_element__() + def _query_clause_element(self): + """like __clause_element__(), but called specifically + by :class:`.Query` to allow special behavior.""" + + return self.comparator._query_clause_element() + + def adapt_to_entity(self, adapt_to_entity): + assert not self._of_type + return self.__class__(adapt_to_entity.entity, + self.key, impl=self.impl, + comparator=self.comparator.adapt_to_entity( + adapt_to_entity), + parententity=adapt_to_entity) + + def of_type(self, cls): + return QueryableAttribute( + self.class_, + self.key, + self.impl, + self.comparator.of_type(cls), + self._parententity, + of_type=cls) + def label(self, name): - return self.__clause_element__().label(name) + return self._query_clause_element().label(name) def operate(self, op, *other, **kwargs): return op(self.comparator, *other, **kwargs) @@ -123,32 +178,50 @@ class QueryableAttribute(interfaces.PropComparator): return op(other, self.comparator, **kwargs) def hasparent(self, state, optimistic=False): - return self.impl.hasparent(state, optimistic=optimistic) - + return self.impl.hasparent(state, optimistic=optimistic) is not False + def __getattr__(self, key): try: return getattr(self.comparator, key) except AttributeError: raise AttributeError( - 'Neither %r object nor %r object has an attribute %r' % ( - type(self).__name__, - type(self.comparator).__name__, + 'Neither %r object nor %r object associated with %s ' + 'has an attribute %r' % ( + type(self).__name__, + type(self.comparator).__name__, + self, key) ) - - def __str__(self): - return repr(self.parententity) + "." + self.property.key - @property + def __str__(self): + return "%s.%s" % (self.class_.__name__, self.key) + + @util.memoized_property def property(self): + """Return the :class:`.MapperProperty` associated with this + :class:`.QueryableAttribute`. + + + Return values here will commonly be instances of + :class:`.ColumnProperty` or :class:`.RelationshipProperty`. + + + """ return self.comparator.property class InstrumentedAttribute(QueryableAttribute): - """Public-facing descriptor, placed in the mapped class dictionary.""" + """Class bound instrumented attribute which adds basic + :term:`descriptor` methods. + + See :class:`.QueryableAttribute` for a description of most features. + + + """ def __set__(self, instance, value): - self.impl.set(instance_state(instance), instance_dict(instance), value, None) + self.impl.set(instance_state(instance), + instance_dict(instance), value, None) def __delete__(self, instance): self.impl.delete(instance_state(instance), instance_dict(instance)) @@ -156,67 +229,88 @@ class InstrumentedAttribute(QueryableAttribute): def __get__(self, instance, owner): if instance is None: return self - return self.impl.get(instance_state(instance), instance_dict(instance)) -class _ProxyImpl(object): - accepts_scalar_loader = False - expire_missing = True - - def __init__(self, key): - self.key = key + dict_ = instance_dict(instance) + if self._supports_population and self.key in dict_: + return dict_[self.key] + else: + return self.impl.get(instance_state(instance), dict_) -def proxied_attribute_factory(descriptor): - """Create an InstrumentedAttribute / user descriptor hybrid. 
-    Returns a new InstrumentedAttribute type that delegates descriptor
+def create_proxied_attribute(descriptor):
+    """Create a QueryableAttribute / user descriptor hybrid.
+
+    Returns a new QueryableAttribute type that delegates descriptor
     behavior and getattr() to the given descriptor.
     """
 
-    class Proxy(InstrumentedAttribute):
-        """A combination of InsturmentedAttribute and a regular descriptor."""
+    # TODO: can move this to descriptor_props if the need for this
+    # function is removed from ext/hybrid.py
 
-        def __init__(self, key, descriptor, comparator, parententity):
+    class Proxy(QueryableAttribute):
+        """Presents the :class:`.QueryableAttribute` interface as a
+        proxy on top of a Python descriptor / :class:`.PropComparator`
+        combination.
+
+        """
+
+        def __init__(self, class_, key, descriptor,
+                     comparator,
+                     adapt_to_entity=None, doc=None,
+                     original_property=None):
+            self.class_ = class_
             self.key = key
-            # maintain ProxiedAttribute.user_prop compatability.
-            self.descriptor = self.user_prop = descriptor
+            self.descriptor = descriptor
+            self.original_property = original_property
             self._comparator = comparator
-            self._parententity = parententity
-            self.impl = _ProxyImpl(key)
+            self._adapt_to_entity = adapt_to_entity
+            self.__doc__ = doc
+
+        @property
+        def property(self):
+            return self.comparator.property
 
         @util.memoized_property
         def comparator(self):
             if util.callable(self._comparator):
                 self._comparator = self._comparator()
+            if self._adapt_to_entity:
+                self._comparator = self._comparator.adapt_to_entity(
+                    self._adapt_to_entity)
             return self._comparator
 
+        def adapt_to_entity(self, adapt_to_entity):
+            return self.__class__(adapt_to_entity.entity,
+                                  self.key,
+                                  self.descriptor,
+                                  self._comparator,
+                                  adapt_to_entity)
+
         def __get__(self, instance, owner):
-            """Delegate __get__ to the original descriptor."""
             if instance is None:
-                descriptor.__get__(instance, owner)
                 return self
-            return descriptor.__get__(instance, owner)
+            else:
+                return self.descriptor.__get__(instance, owner)
 
-        def __set__(self, instance, value):
-            """Delegate __set__ to the original descriptor."""
-            return descriptor.__set__(instance, value)
-
-        def __delete__(self, instance):
-            """Delegate __delete__ to the original descriptor."""
-            return descriptor.__delete__(instance)
+        def __str__(self):
+            return "%s.%s" % (self.class_.__name__, self.key)
 
         def __getattr__(self, attribute):
-            """Delegate __getattr__ to the original descriptor and/or comparator."""
-
+            """Delegate __getattr__ to the original descriptor and/or
+            comparator."""
+
             try:
                 return getattr(descriptor, attribute)
             except AttributeError:
                 try:
-                    return getattr(self._comparator, attribute)
+                    return getattr(self.comparator, attribute)
                 except AttributeError:
                     raise AttributeError(
-                        'Neither %r object nor %r object has an attribute %r' % (
-                            type(descriptor).__name__,
-                            type(self._comparator).__name__,
+                        'Neither %r object nor %r object associated with %s '
+                        'has an attribute %r' % (
+                            type(descriptor).__name__,
+                            type(self.comparator).__name__,
+                            self,
                             attribute)
                     )
 
@@ -227,19 +321,72 @@ def proxied_attribute_factory(descriptor):
                 from_instance=descriptor)
     return Proxy
 
+OP_REMOVE = util.symbol("REMOVE")
+OP_APPEND = util.symbol("APPEND")
+OP_REPLACE = util.symbol("REPLACE")
+
+
+class Event(object):
+    """A token propagated throughout the course of a chain of attribute
+    events.
+
+    Serves as an indicator of the source of the event and also provides
+    a means of controlling propagation across a chain of attribute
+    operations.
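+
+    For example, a listener receives the token as its ``initiator``
+    argument and can examine it to distinguish the source operation
+    (a sketch; ``SomeClass.collection`` is a hypothetical instrumented
+    attribute)::
+
+        from sqlalchemy import event
+
+        @event.listens_for(SomeClass.collection, 'append')
+        def on_append(target, value, initiator):
+            # initiator is an Event token; its .op attribute is one of
+            # OP_APPEND, OP_REMOVE or OP_REPLACE
+            print(initiator.op)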
+ + The :class:`.Event` object is sent as the ``initiator`` argument + when dealing with the :meth:`.AttributeEvents.append`, + :meth:`.AttributeEvents.set`, + and :meth:`.AttributeEvents.remove` events. + + The :class:`.Event` object is currently interpreted by the backref + event handlers, and is used to control the propagation of operations + across two mutually-dependent attributes. + + .. versionadded:: 0.9.0 + + :var impl: The :class:`.AttributeImpl` which is the current event + initiator. + + :var op: The symbol :attr:`.OP_APPEND`, :attr:`.OP_REMOVE` or + :attr:`.OP_REPLACE`, indicating the source operation. + + """ + + __slots__ = 'impl', 'op', 'parent_token' + + def __init__(self, attribute_impl, op): + self.impl = attribute_impl + self.op = op + self.parent_token = self.impl.parent_token + + def __eq__(self, other): + return isinstance(other, Event) and \ + other.impl is self.impl and \ + other.op == self.op + + @property + def key(self): + return self.impl.key + + def hasparent(self, state): + return self.impl.hasparent(state) + + class AttributeImpl(object): """internal implementation for instrumented attributes.""" def __init__(self, class_, key, - callable_, trackparent=False, extension=None, - compare_function=None, active_history=False, - parent_token=None, expire_missing=True, - **kwargs): - """Construct an AttributeImpl. + callable_, dispatch, trackparent=False, extension=None, + compare_function=None, active_history=False, + parent_token=None, expire_missing=True, + send_modified_events=True, + **kwargs): + r"""Construct an AttributeImpl. \class_ associated class - + key string name of the attribute @@ -255,7 +402,8 @@ class AttributeImpl(object): extension a single or list of AttributeExtension object(s) which will - receive set/delete/append/remove/etc. events. + receive set/delete/append/remove/etc. events. Deprecated. + The event package is now used. compare_function a function that compares two values which are normally @@ -268,34 +416,63 @@ class AttributeImpl(object): parent_token Usually references the MapperProperty, used as a key for the hasparent() function to identify an "owning" attribute. - Allows multiple AttributeImpls to all match a single + Allows multiple AttributeImpls to all match a single owner attribute. - + expire_missing if False, don't add an "expiry" callable to this attribute - during state.expire_attributes(None), if no value is present + during state.expire_attributes(None), if no value is present for this key. - + + send_modified_events + if False, the InstanceState._modified_event method will have no + effect; this means the attribute will never show up as changed in a + history entry. 
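+
+        As a rough sketch of direct usage (names here are hypothetical;
+        these arguments ordinarily arrive via mapper configuration
+        through :func:`.register_attribute`)::
+
+            register_attribute(
+                MyClass, 'related', useobject=True,
+                trackparent=True, active_history=True,
+                comparator=comp, parententity=mapper)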
""" self.class_ = class_ self.key = key self.callable_ = callable_ + self.dispatch = dispatch self.trackparent = trackparent self.parent_token = parent_token or self + self.send_modified_events = send_modified_events if compare_function is None: self.is_equal = operator.eq else: self.is_equal = compare_function - self.extensions = util.to_list(extension or []) - for e in self.extensions: - if e.active_history: - active_history = True - break - self.active_history = active_history + + # TODO: pass in the manager here + # instead of doing a lookup + attr = manager_of_class(class_)[key] + + for ext in util.to_list(extension or []): + ext._adapt_listener(attr, ext) + + if active_history: + self.dispatch._active_history = True + self.expire_missing = expire_missing - + + __slots__ = ( + 'class_', 'key', 'callable_', 'dispatch', 'trackparent', + 'parent_token', 'send_modified_events', 'is_equal', 'expire_missing' + ) + + def __str__(self): + return "%s.%s" % (self.class_.__name__, self.key) + + def _get_active_history(self): + """Backwards compat for impl.active_history""" + + return self.dispatch._active_history + + def _set_active_history(self, value): + self.dispatch._active_history = value + + active_history = property(_get_active_history, _set_active_history) + def hasparent(self, state, optimistic=False): - """Return the boolean value of a `hasparent` flag attached to + """Return the boolean value of a `hasparent` flag attached to the given state. The `optimistic` flag determines what the default return value @@ -310,310 +487,357 @@ class AttributeImpl(object): will also not have a `hasparent` flag. """ - return state.parents.get(id(self.parent_token), optimistic) + msg = "This AttributeImpl is not configured to track parents." + assert self.trackparent, msg - def sethasparent(self, state, value): + return state.parents.get(id(self.parent_token), optimistic) \ + is not False + + def sethasparent(self, state, parent_state, value): """Set a boolean flag on the given item corresponding to whether or not it is attached to a parent object via the attribute represented by this ``InstrumentedAttribute``. """ - state.parents[id(self.parent_token)] = value + msg = "This AttributeImpl is not configured to track parents." + assert self.trackparent, msg - def set_callable(self, state, callable_): - """Set a callable function for this attribute on the given object. + id_ = id(self.parent_token) + if value: + state.parents[id_] = parent_state + else: + if id_ in state.parents: + last_parent = state.parents[id_] - This callable will be executed when the attribute is next - accessed, and is assumed to construct part of the instances - previously stored state. When its value or values are loaded, - they will be established as part of the instance's *committed - state*. While *trackparent* information will be assembled for - these instances, attribute-level event handlers will not be - fired. + if last_parent is not False and \ + last_parent.key != parent_state.key: - The callable overrides the class level callable set in the - ``InstrumentedAttribute`` constructor. + if last_parent.obj() is None: + raise orm_exc.StaleDataError( + "Removing state %s from parent " + "state %s along attribute '%s', " + "but the parent record " + "has gone stale, can't be sure this " + "is the most recent parent." 
% + (state_str(state), + state_str(parent_state), + self.key)) - """ - state.callables[self.key] = callable_ + return + + state.parents[id_] = False def get_history(self, state, dict_, passive=PASSIVE_OFF): raise NotImplementedError() - def _get_callable(self, state): - if self.key in state.callables: - return state.callables[self.key] - elif self.callable_ is not None: - return self.callable_(state) - else: - return None + def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE): + """Return a list of tuples of (state, obj) + for all objects in this attribute's current state + + history. + + Only applies to object-based attributes. + + This is an inlining of existing functionality + which roughly corresponds to: + + get_state_history( + state, + key, + passive=PASSIVE_NO_INITIALIZE).sum() + + """ + raise NotImplementedError() def initialize(self, state, dict_): """Initialize the given state's attribute with an empty value.""" - dict_[self.key] = None - return None + value = None + for fn in self.dispatch.init_scalar: + ret = fn(state, value, dict_) + if ret is not ATTR_EMPTY: + value = ret + + return value def get(self, state, dict_, passive=PASSIVE_OFF): """Retrieve a value from the given object. - If a callable is assembled on this object's attribute, and passive is False, the callable will be executed and the resulting value will be set as the new value for this attribute. """ - - try: + if self.key in dict_: return dict_[self.key] - except KeyError: - # if no history, check for lazy callables, etc. - if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET: - if passive is PASSIVE_NO_INITIALIZE: + else: + # if history present, don't load + key = self.key + if key not in state.committed_state or \ + state.committed_state[key] is NEVER_SET: + if not passive & CALLABLES_OK: return PASSIVE_NO_RESULT - - callable_ = self._get_callable(state) - if callable_ is not None: - #if passive is not PASSIVE_OFF: - # return PASSIVE_NO_RESULT - value = callable_(passive=passive) - if value is PASSIVE_NO_RESULT: - return value - elif value is not ATTR_WAS_SET: - return self.set_committed_value(state, dict_, value) - else: - if self.key not in dict_: - return self.get(state, dict_, passive=passive) - return dict_[self.key] - # Return a new, empty value - return self.initialize(state, dict_) + if key in state.expired_attributes: + value = state._load_expired(state, passive) + elif key in state.callables: + callable_ = state.callables[key] + value = callable_(state, passive) + elif self.callable_: + value = self.callable_(state, passive) + else: + value = ATTR_EMPTY + + if value is PASSIVE_NO_RESULT or value is NEVER_SET: + return value + elif value is ATTR_WAS_SET: + try: + return dict_[key] + except KeyError: + # TODO: no test coverage here. 
+ raise KeyError( + "Deferred loader for attribute " + "%r failed to populate " + "correctly" % key) + elif value is not ATTR_EMPTY: + return self.set_committed_value(state, dict_, value) + + if not passive & INIT_OK: + return NEVER_SET + else: + # Return a new, empty value + return self.initialize(state, dict_) def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): self.set(state, dict_, value, initiator, passive=passive) def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): - self.set(state, dict_, None, initiator, passive=passive) + self.set(state, dict_, None, initiator, + passive=passive, check_old=value) - def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, None, initiator, + passive=passive, check_old=value, pop=True) + + def set(self, state, dict_, value, initiator, + passive=PASSIVE_OFF, check_old=None, pop=False): raise NotImplementedError() def get_committed_value(self, state, dict_, passive=PASSIVE_OFF): """return the unchanged value of this attribute""" if self.key in state.committed_state: - if state.committed_state[self.key] is NO_VALUE: + value = state.committed_state[self.key] + if value in (NO_VALUE, NEVER_SET): return None else: - return state.committed_state.get(self.key) + return value else: return self.get(state, dict_, passive=passive) def set_committed_value(self, state, dict_, value): """set an attribute value on the given instance and 'commit' it.""" - state.commit(dict_, [self.key]) - - state.callables.pop(self.key, None) - state.dict[self.key] = value - + dict_[self.key] = value + state._commit(dict_, [self.key]) return value + class ScalarAttributeImpl(AttributeImpl): """represents a scalar value-holding InstrumentedAttribute.""" accepts_scalar_loader = True uses_objects = False + supports_population = True + collection = False + + __slots__ = '_replace_token', '_append_token', '_remove_token' + + def __init__(self, *arg, **kw): + super(ScalarAttributeImpl, self).__init__(*arg, **kw) + self._replace_token = self._append_token = None + self._remove_token = None + + def _init_append_token(self): + self._replace_token = self._append_token = Event(self, OP_REPLACE) + return self._replace_token + + _init_append_or_replace_token = _init_append_token + + def _init_remove_token(self): + self._remove_token = Event(self, OP_REMOVE) + return self._remove_token def delete(self, state, dict_): # TODO: catch key errors, convert to attributeerror? 
-        if self.active_history:
-            old = self.get(state, dict_)
+        if self.dispatch._active_history:
+            old = self.get(state, dict_, PASSIVE_RETURN_NEVER_SET)
         else:
             old = dict_.get(self.key, NO_VALUE)
 
-        if self.extensions:
-            self.fire_remove_event(state, dict_, old, None)
-        state.modified_event(dict_, self, False, old)
+        if self.dispatch.remove:
+            self.fire_remove_event(state, dict_, old, self._remove_token)
+        state._modified_event(dict_, self, old)
         del dict_[self.key]
 
     def get_history(self, state, dict_, passive=PASSIVE_OFF):
-        return History.from_attribute(
-            self, state, dict_.get(self.key, NO_VALUE))
+        if self.key in dict_:
+            return History.from_scalar_attribute(self, state, dict_[self.key])
+        else:
+            if passive & INIT_OK:
+                passive ^= INIT_OK
+            current = self.get(state, dict_, passive=passive)
+            if current is PASSIVE_NO_RESULT:
+                return HISTORY_BLANK
+            else:
+                return History.from_scalar_attribute(self, state, current)
 
-    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        if initiator is self:
-            return
-
-        if self.active_history:
-            old = self.get(state, dict_)
+    def set(self, state, dict_, value, initiator,
+            passive=PASSIVE_OFF, check_old=None, pop=False):
+        if self.dispatch._active_history:
+            old = self.get(state, dict_, PASSIVE_RETURN_NEVER_SET)
         else:
             old = dict_.get(self.key, NO_VALUE)
 
-        if self.extensions:
-            value = self.fire_replace_event(state, dict_, value, old, initiator)
-        state.modified_event(dict_, self, False, old)
+        if self.dispatch.set:
+            value = self.fire_replace_event(state, dict_,
+                                            value, old, initiator)
+        state._modified_event(dict_, self, old)
         dict_[self.key] = value
 
     def fire_replace_event(self, state, dict_, value, previous, initiator):
-        for ext in self.extensions:
-            value = ext.set(state, value, previous, initiator or self)
+        for fn in self.dispatch.set:
+            value = fn(
+                state, value, previous,
+                initiator or self._replace_token or
+                self._init_append_or_replace_token())
         return value
 
     def fire_remove_event(self, state, dict_, value, initiator):
-        for ext in self.extensions:
-            ext.remove(state, value, initiator or self)
+        for fn in self.dispatch.remove:
+            fn(state, value,
+               initiator or self._remove_token or self._init_remove_token())
 
     @property
     def type(self):
-        self.property.columns[0].type
+        return self.property.columns[0].type
 
 
-class MutableScalarAttributeImpl(ScalarAttributeImpl):
-    """represents a scalar value-holding InstrumentedAttribute, which can detect
-    changes within the value itself.
- """ - - uses_objects = False - - def __init__(self, class_, key, callable_, - class_manager, copy_function=None, - compare_function=None, **kwargs): - super(ScalarAttributeImpl, self).__init__( - class_, - key, - callable_, - compare_function=compare_function, - **kwargs) - class_manager.mutable_attributes.add(key) - if copy_function is None: - raise sa_exc.ArgumentError( - "MutableScalarAttributeImpl requires a copy function") - self.copy = copy_function - - def get_history(self, state, dict_, passive=PASSIVE_OFF): - if not dict_: - v = state.committed_state.get(self.key, NO_VALUE) - else: - v = dict_.get(self.key, NO_VALUE) - - return History.from_attribute( - self, state, v) - - def check_mutable_modified(self, state, dict_): - added, \ - unchanged, \ - deleted = self.get_history(state, dict_, passive=PASSIVE_NO_INITIALIZE) - return bool(added or deleted) - - def get(self, state, dict_, passive=PASSIVE_OFF): - if self.key not in state.mutable_dict: - ret = ScalarAttributeImpl.get(self, state, dict_, passive=passive) - if ret is not PASSIVE_NO_RESULT: - state.mutable_dict[self.key] = ret - return ret - else: - return state.mutable_dict[self.key] - - def delete(self, state, dict_): - ScalarAttributeImpl.delete(self, state, dict_) - state.mutable_dict.pop(self.key) - - def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): - if initiator is self: - return - - if self.extensions: - old = self.get(state, dict_) - value = self.fire_replace_event(state, dict_, value, old, initiator) - - state.modified_event(dict_, self, True, NEVER_SET) - dict_[self.key] = value - state.mutable_dict[self.key] = value - - class ScalarObjectAttributeImpl(ScalarAttributeImpl): - """represents a scalar-holding InstrumentedAttribute, + """represents a scalar-holding InstrumentedAttribute, where the target object is also instrumented. Adds events to delete/set operations. 
- + """ accepts_scalar_loader = False uses_objects = True + supports_population = True + collection = False - def __init__(self, class_, key, callable_, - trackparent=False, extension=None, copy_function=None, - compare_function=None, **kwargs): - super(ScalarObjectAttributeImpl, self).__init__( - class_, - key, - callable_, - trackparent=trackparent, - extension=extension, - compare_function=compare_function, - **kwargs) - if compare_function is None: - self.is_equal = identity_equal + __slots__ = () def delete(self, state, dict_): old = self.get(state, dict_) - self.fire_remove_event(state, dict_, old, self) + self.fire_remove_event( + state, dict_, old, + self._remove_token or self._init_remove_token()) del dict_[self.key] def get_history(self, state, dict_, passive=PASSIVE_OFF): if self.key in dict_: - return History.from_attribute(self, state, dict_[self.key]) + return History.from_object_attribute(self, state, dict_[self.key]) else: + if passive & INIT_OK: + passive ^= INIT_OK current = self.get(state, dict_, passive=passive) if current is PASSIVE_NO_RESULT: return HISTORY_BLANK else: - return History.from_attribute(self, state, current) + return History.from_object_attribute(self, state, current) - def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE): + if self.key in dict_: + current = dict_[self.key] + elif passive & CALLABLES_OK: + current = self.get(state, dict_, passive=passive) + else: + return [] + + # can't use __hash__(), can't use __eq__() here + if current is not None and \ + current is not PASSIVE_NO_RESULT and \ + current is not NEVER_SET: + ret = [(instance_state(current), current)] + else: + ret = [(None, None)] + + if self.key in state.committed_state: + original = state.committed_state[self.key] + if original is not None and \ + original is not PASSIVE_NO_RESULT and \ + original is not NEVER_SET and \ + original is not current: + + ret.append((instance_state(original), original)) + return ret + + def set(self, state, dict_, value, initiator, + passive=PASSIVE_OFF, check_old=None, pop=False): """Set a value on the given InstanceState. - `initiator` is the ``InstrumentedAttribute`` that initiated the - ``set()`` operation and is used to control the depth of a circular - setter operation. 
- """ - if initiator is self: - return - - if self.active_history: - old = self.get(state, dict_) + if self.dispatch._active_history: + old = self.get( + state, dict_, + passive=PASSIVE_ONLY_PERSISTENT | + NO_AUTOFLUSH | LOAD_AGAINST_COMMITTED) else: - old = self.get(state, dict_, passive=PASSIVE_NO_FETCH) - + old = self.get( + state, dict_, passive=PASSIVE_NO_FETCH ^ INIT_OK | + LOAD_AGAINST_COMMITTED) + + if check_old is not None and \ + old is not PASSIVE_NO_RESULT and \ + check_old is not old: + if pop: + return + else: + raise ValueError( + "Object %s not associated with %s on attribute '%s'" % ( + instance_str(check_old), + state_str(state), + self.key + )) + value = self.fire_replace_event(state, dict_, value, old, initiator) dict_[self.key] = value def fire_remove_event(self, state, dict_, value, initiator): if self.trackparent and value is not None: - self.sethasparent(instance_state(value), False) + self.sethasparent(instance_state(value), state, False) - for ext in self.extensions: - ext.remove(state, value, initiator or self) + for fn in self.dispatch.remove: + fn(state, value, initiator or + self._remove_token or self._init_remove_token()) - state.modified_event(dict_, self, False, value) + state._modified_event(dict_, self, value) def fire_replace_event(self, state, dict_, value, previous, initiator): if self.trackparent: if (previous is not value and - previous is not None and - previous is not PASSIVE_NO_RESULT): - self.sethasparent(instance_state(previous), False) + previous not in (None, PASSIVE_NO_RESULT, NEVER_SET)): + self.sethasparent(instance_state(previous), state, False) - for ext in self.extensions: - value = ext.set(state, value, previous, initiator or self) + for fn in self.dispatch.set: + value = fn( + state, value, previous, initiator or + self._replace_token or self._init_append_or_replace_token()) - state.modified_event(dict_, self, False, previous) + state._modified_event(dict_, self, previous) if self.trackparent: if value is not None: - self.sethasparent(instance_state(value), True) + self.sethasparent(instance_state(value), state, True) return value @@ -625,68 +849,132 @@ class CollectionAttributeImpl(AttributeImpl): InstrumentedCollectionAttribute holds an arbitrary, user-specified container object (defaulting to a list) and brokers access to the - CollectionAdapter, a "view" onto that object that presents consistent - bag semantics to the orm layer independent of the user data implementation. + CollectionAdapter, a "view" onto that object that presents consistent bag + semantics to the orm layer independent of the user data implementation. 
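+
+    For example, the adapter for an already-loaded collection can be
+    retrieved via :func:`.collection_adapter` (a sketch;
+    ``someobj.addresses`` is a hypothetical collection attribute)::
+
+        from sqlalchemy.orm.collections import collection_adapter
+
+        adapter = collection_adapter(someobj.addresses)
+        for member in adapter:
+            print(member)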
""" accepts_scalar_loader = False uses_objects = True + supports_population = True + collection = True - def __init__(self, class_, key, callable_, - typecallable=None, trackparent=False, extension=None, - copy_function=None, compare_function=None, **kwargs): + __slots__ = ( + 'copy', 'collection_factory', '_append_token', '_remove_token', + '_duck_typed_as' + ) + + def __init__(self, class_, key, callable_, dispatch, + typecallable=None, trackparent=False, extension=None, + copy_function=None, compare_function=None, **kwargs): super(CollectionAttributeImpl, self).__init__( - class_, - key, - callable_, - trackparent=trackparent, - extension=extension, - compare_function=compare_function, - **kwargs) + class_, + key, + callable_, dispatch, + trackparent=trackparent, + extension=extension, + compare_function=compare_function, + **kwargs) if copy_function is None: copy_function = self.__copy self.copy = copy_function self.collection_factory = typecallable + self._append_token = None + self._remove_token = None + self._duck_typed_as = util.duck_type_collection( + self.collection_factory()) + + if getattr(self.collection_factory, "_sa_linker", None): + + @event.listens_for(self, "init_collection") + def link(target, collection, collection_adapter): + collection._sa_linker(collection_adapter) + + @event.listens_for(self, "dispose_collection") + def unlink(target, collection, collection_adapter): + collection._sa_linker(None) + + def _init_append_token(self): + self._append_token = Event(self, OP_APPEND) + return self._append_token + + def _init_remove_token(self): + self._remove_token = Event(self, OP_REMOVE) + return self._remove_token def __copy(self, item): - return [y for y in list(collections.collection_adapter(item))] + return [y for y in collections.collection_adapter(item)] def get_history(self, state, dict_, passive=PASSIVE_OFF): current = self.get(state, dict_, passive=passive) if current is PASSIVE_NO_RESULT: return HISTORY_BLANK else: - return History.from_attribute(self, state, current) + return History.from_collection(self, state, current) + + def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE): + # NOTE: passive is ignored here at the moment + + if self.key not in dict_: + return [] + + current = dict_[self.key] + current = getattr(current, '_sa_adapter') + + if self.key in state.committed_state: + original = state.committed_state[self.key] + if original not in (NO_VALUE, NEVER_SET): + current_states = [((c is not None) and + instance_state(c) or None, c) + for c in current] + original_states = [((c is not None) and + instance_state(c) or None, c) + for c in original] + + current_set = dict(current_states) + original_set = dict(original_states) + + return \ + [(s, o) for s, o in current_states + if s not in original_set] + \ + [(s, o) for s, o in current_states + if s in original_set] + \ + [(s, o) for s, o in original_states + if s not in current_set] + + return [(instance_state(o), o) for o in current] def fire_append_event(self, state, dict_, value, initiator): - for ext in self.extensions: - value = ext.append(state, value, initiator or self) + for fn in self.dispatch.append: + value = fn( + state, value, + initiator or self._append_token or self._init_append_token()) - state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + state._modified_event(dict_, self, NEVER_SET, True) if self.trackparent and value is not None: - self.sethasparent(instance_state(value), True) + self.sethasparent(instance_state(value), state, True) return 
value def fire_pre_remove_event(self, state, dict_, initiator): - state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + state._modified_event(dict_, self, NEVER_SET, True) def fire_remove_event(self, state, dict_, value, initiator): if self.trackparent and value is not None: - self.sethasparent(instance_state(value), False) + self.sethasparent(instance_state(value), state, False) - for ext in self.extensions: - ext.remove(state, value, initiator or self) + for fn in self.dispatch.remove: + fn(state, value, + initiator or self._remove_token or self._init_remove_token()) - state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + state._modified_event(dict_, self, NEVER_SET, True) def delete(self, state, dict_): if self.key not in dict_: return - state.modified_event(dict_, self, True, NEVER_SET) + state._modified_event(dict_, self, NEVER_SET, True) collection = self.get_collection(state, state.dict) collection.clear_with_event() @@ -701,82 +989,103 @@ class CollectionAttributeImpl(AttributeImpl): return user_data def _initialize_collection(self, state): - return state.manager.initialize_collection( + + adapter, collection = state.manager.initialize_collection( self.key, state, self.collection_factory) - def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): - if initiator is self: - return + self.dispatch.init_collection(state, collection, adapter) + return adapter, collection + + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): collection = self.get_collection(state, dict_, passive=passive) if collection is PASSIVE_NO_RESULT: value = self.fire_append_event(state, dict_, value, initiator) - assert self.key not in dict_, "Collection was loaded during event handling." - state.get_pending(self.key).append(value) + assert self.key not in dict_, \ + "Collection was loaded during event handling." + state._get_pending_mutation(self.key).append(value) else: collection.append_with_event(value, initiator) def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): - if initiator is self: - return - collection = self.get_collection(state, state.dict, passive=passive) if collection is PASSIVE_NO_RESULT: self.fire_remove_event(state, dict_, value, initiator) - assert self.key not in dict_, "Collection was loaded during event handling." - state.get_pending(self.key).remove(value) + assert self.key not in dict_, \ + "Collection was loaded during event handling." + state._get_pending_mutation(self.key).remove(value) else: collection.remove_with_event(value, initiator) - def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): - """Set a value on the given object. + def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + try: + # TODO: better solution here would be to add + # a "popper" role to collections.py to complement + # "remover". + self.remove(state, dict_, value, initiator, passive=passive) + except (ValueError, KeyError, IndexError): + pass - `initiator` is the ``InstrumentedAttribute`` that initiated the - ``set()`` operation and is used to control the depth of a circular - setter operation. 
- """ + def set(self, state, dict_, value, initiator=None, + passive=PASSIVE_OFF, pop=False, _adapt=True): + iterable = orig_iterable = value - if initiator is self: - return - - self._set_iterable( - state, dict_, value, - lambda adapter, i: adapter.adapt_like_to_iterable(i)) - - def _set_iterable(self, state, dict_, iterable, adapter=None): - """Set a collection value from an iterable of state-bearers. - - ``adapter`` is an optional callable invoked with a CollectionAdapter - and the iterable. Should return an iterable of state-bearing - instances suitable for appending via a CollectionAdapter. Can be used - for, e.g., adapting an incoming dictionary into an iterator of values - rather than keys. - - """ # pulling a new collection first so that an adaptation exception does # not trigger a lazy load of the old collection. new_collection, user_data = self._initialize_collection(state) - if adapter: - new_values = list(adapter(new_collection, iterable)) - else: - new_values = list(iterable) + if _adapt: + if new_collection._converter is not None: + iterable = new_collection._converter(iterable) + else: + setting_type = util.duck_type_collection(iterable) + receiving_type = self._duck_typed_as - old = self.get(state, dict_) + if setting_type is not receiving_type: + given = iterable is None and 'None' or \ + iterable.__class__.__name__ + wanted = self._duck_typed_as.__name__ + raise TypeError( + "Incompatible collection type: %s is not %s-like" % ( + given, wanted)) - # ignore re-assignment of the current collection, as happens - # implicitly with in-place operators (foo.collection |= other) - if old is iterable: + # If the object is an adapted collection, return the (iterable) + # adapter. + if hasattr(iterable, '_sa_iterator'): + iterable = iterable._sa_iterator() + elif setting_type is dict: + if util.py3k: + iterable = iterable.values() + else: + iterable = getattr( + iterable, 'itervalues', iterable.values)() + else: + iterable = iter(iterable) + new_values = list(iterable) + + old = self.get(state, dict_, passive=PASSIVE_ONLY_PERSISTENT) + if old is PASSIVE_NO_RESULT: + old = self.initialize(state, dict_) + elif old is orig_iterable: + # ignore re-assignment of the current collection, as happens + # implicitly with in-place operators (foo.collection |= other) return - state.modified_event(dict_, self, True, old) + # place a copy of "old" in state.committed_state + state._modified_event(dict_, self, old, True) - old_collection = self.get_collection(state, dict_, old) + old_collection = old._sa_adapter dict_[self.key] = user_data - collections.bulk_replace(new_values, old_collection, new_collection) - old_collection.unlink(old) + collections.bulk_replace( + new_values, old_collection, new_collection) + del old._sa_adapter + self.dispatch.dispose_collection(state, old, old_collection) + + def _invalidate_collection(self, collection): + adapter = getattr(collection, '_sa_adapter') + adapter.invalidated = True def set_committed_value(self, state, dict_, value): """Set an attribute value on the given instance and 'commit' it.""" @@ -784,21 +1093,18 @@ class CollectionAttributeImpl(AttributeImpl): collection, user_data = self._initialize_collection(state) if value: - for item in value: - collection.append_without_event(item) + collection.append_multiple_without_event(value) - state.callables.pop(self.key, None) state.dict[self.key] = user_data - state.commit(dict_, [self.key]) + state._commit(dict_, [self.key]) - if self.key in state.pending: - + if self.key in state._pending_mutations: # 
pending items exist. issue a modified event, # add/remove new items. - state.modified_event(dict_, self, True, user_data) + state._modified_event(dict_, self, user_data, True) - pending = state.pending.pop(self.key) + pending = state._pending_mutations.pop(self.key) added = pending.added_items removed = pending.deleted_items for item in added: @@ -808,7 +1114,8 @@ class CollectionAttributeImpl(AttributeImpl): return user_data - def get_collection(self, state, dict_, user_data=None, passive=PASSIVE_OFF): + def get_collection(self, state, dict_, + user_data=None, passive=PASSIVE_OFF): """Retrieve the CollectionAdapter associated with the given state. Creates a new CollectionAdapter if one does not exist. @@ -821,527 +1128,321 @@ class CollectionAttributeImpl(AttributeImpl): return getattr(user_data, '_sa_adapter') -class GenericBackrefExtension(interfaces.AttributeExtension): - """An extension which synchronizes a two-way relationship. - A typical two-way relationship is a parent object containing a list of - child objects, where each child object references the parent. The other - are two objects which contain scalar references to each other. +def backref_listeners(attribute, key, uselist): + """Apply listeners to synchronize a two-way relationship.""" - """ - - active_history = False - - def __init__(self, key): - self.key = key + # use easily recognizable names for stack traces - def set(self, state, child, oldchild, initiator): + parent_token = attribute.impl.parent_token + parent_impl = attribute.impl + + def _acceptable_key_err(child_state, initiator, child_impl): + raise ValueError( + "Bidirectional attribute conflict detected: " + 'Passing object %s to attribute "%s" ' + 'triggers a modify event on attribute "%s" ' + 'via the backref "%s".' % ( + state_str(child_state), + initiator.parent_token, + child_impl.parent_token, + attribute.impl.parent_token + ) + ) + + def emit_backref_from_scalar_set_event(state, child, oldchild, initiator): if oldchild is child: return child - - if oldchild is not None and oldchild is not PASSIVE_NO_RESULT: + if oldchild is not None and \ + oldchild is not PASSIVE_NO_RESULT and \ + oldchild is not NEVER_SET: # With lazy=None, there's no guarantee that the full collection is # present when updating via a backref. 
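+            # (the pop below therefore runs with PASSIVE_NO_FETCH,
+            # detaching the parent from the old child without emitting
+            # SQL to load that collection)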
- old_state, old_dict = instance_state(oldchild), instance_dict(oldchild) - impl = old_state.get_impl(self.key) - try: - impl.remove(old_state, - old_dict, - state.obj(), - initiator, passive=PASSIVE_NO_FETCH) - except (ValueError, KeyError, IndexError): - pass - + old_state, old_dict = instance_state(oldchild),\ + instance_dict(oldchild) + impl = old_state.manager[key].impl + + if initiator.impl is not impl or \ + initiator.op not in (OP_REPLACE, OP_REMOVE): + impl.pop(old_state, + old_dict, + state.obj(), + parent_impl._append_token or + parent_impl._init_append_token(), + passive=PASSIVE_NO_FETCH) + if child is not None: - child_state, child_dict = instance_state(child), instance_dict(child) - child_state.get_impl(self.key).append( - child_state, - child_dict, - state.obj(), - initiator, passive=PASSIVE_NO_FETCH) + child_state, child_dict = instance_state(child),\ + instance_dict(child) + child_impl = child_state.manager[key].impl + if initiator.parent_token is not parent_token and \ + initiator.parent_token is not child_impl.parent_token: + _acceptable_key_err(state, initiator, child_impl) + elif initiator.impl is not child_impl or \ + initiator.op not in (OP_APPEND, OP_REPLACE): + child_impl.append( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH) return child - def append(self, state, child, initiator): - child_state, child_dict = instance_state(child), instance_dict(child) - child_state.get_impl(self.key).append( - child_state, - child_dict, - state.obj(), - initiator, passive=PASSIVE_NO_FETCH) - return child - - def remove(self, state, child, initiator): - if child is not None: - child_state, child_dict = instance_state(child), instance_dict(child) - child_state.get_impl(self.key).remove( - child_state, - child_dict, - state.obj(), - initiator, passive=PASSIVE_NO_FETCH) - - -class Events(object): - def __init__(self): - self.original_init = object.__init__ - # Initialize to tuples instead of lists to minimize the memory - # footprint - self.on_init = () - self.on_init_failure = () - self.on_load = () - self.on_resurrect = () - - def run(self, event, *args): - for fn in getattr(self, event): - fn(*args) - - def add_listener(self, event, listener): - # not thread safe... problem? mb: nope - bucket = getattr(self, event) - if bucket == (): - setattr(self, event, [listener]) - else: - bucket.append(listener) - - def remove_listener(self, event, listener): - bucket = getattr(self, event) - bucket.remove(listener) - - -class ClassManager(dict): - """tracks state information at the class level.""" - - MANAGER_ATTR = '_sa_class_manager' - STATE_ATTR = '_sa_instance_state' - - event_registry_factory = Events - deferred_scalar_loader = None - - def __init__(self, class_): - self.class_ = class_ - self.factory = None # where we came from, for inheritance bookkeeping - self.info = {} - self.mapper = None - self.new_init = None - self.mutable_attributes = set() - self.local_attrs = {} - self.originals = {} - for base in class_.__mro__[-2:0:-1]: # reverse, skipping 1st and last - if not isinstance(base, type): - continue - cls_state = manager_of_class(base) - if cls_state: - self.update(cls_state) - self.events = self.event_registry_factory() - self.manage() - self._instrument_init() - - def _configure_create_arguments(self, - _source=None, - deferred_scalar_loader=None): - """Accept extra **kw arguments passed to create_manager_for_cls. 
- - The current contract of ClassManager and other managers is that they - take a single "cls" argument in their constructor (as per - test/orm/instrumentation.py InstrumentationCollisionTest). This - is to provide consistency with the current API of "class manager" - callables and such which may return various ClassManager and - ClassManager-like instances. So create_manager_for_cls sends - in ClassManager-specific arguments via this method once the - non-proxied ClassManager is available. - - """ - if _source: - deferred_scalar_loader = _source.deferred_scalar_loader - - if deferred_scalar_loader: - self.deferred_scalar_loader = deferred_scalar_loader - - def _subclass_manager(self, cls): - """Create a new ClassManager for a subclass of this ClassManager's class. - - This is called automatically when attributes are instrumented so that - the attributes can be propagated to subclasses against their own - class-local manager, without the need for mappers etc. to have already - pre-configured managers for the full class hierarchy. Mappers - can post-configure the auto-generated ClassManager when needed. - - """ - manager = manager_of_class(cls) - if manager is None: - manager = _create_manager_for_cls(cls, _source=self) - return manager - - def _instrument_init(self): - # TODO: self.class_.__init__ is often the already-instrumented - # __init__ from an instrumented superclass. We still need to make - # our own wrapper, but it would - # be nice to wrap the original __init__ and not our existing wrapper - # of such, since this adds method overhead. - self.events.original_init = self.class_.__init__ - self.new_init = _generate_init(self.class_, self) - self.install_member('__init__', self.new_init) - - def _uninstrument_init(self): - if self.new_init: - self.uninstall_member('__init__') - self.new_init = None - - def _create_instance_state(self, instance): - if self.mutable_attributes: - return state.MutableAttrInstanceState(instance, self) - else: - return state.InstanceState(instance, self) - - def manage(self): - """Mark this instance as the manager for its class.""" - - setattr(self.class_, self.MANAGER_ATTR, self) - - def dispose(self): - """Dissasociate this manager from its class.""" - - delattr(self.class_, self.MANAGER_ATTR) - - def manager_getter(self): - return attrgetter(self.MANAGER_ATTR) - - def instrument_attribute(self, key, inst, propagated=False): - if propagated: - if key in self.local_attrs: - return # don't override local attr with inherited attr - else: - self.local_attrs[key] = inst - self.install_descriptor(key, inst) - self[key] = inst - - for cls in self.class_.__subclasses__(): - manager = self._subclass_manager(cls) - manager.instrument_attribute(key, inst, True) - - def post_configure_attribute(self, key): - pass - - def uninstrument_attribute(self, key, propagated=False): - if key not in self: + def emit_backref_from_collection_append_event(state, child, initiator): + if child is None: return - if propagated: - if key in self.local_attrs: - return # don't get rid of local attr - else: - del self.local_attrs[key] - self.uninstall_descriptor(key) - del self[key] - if key in self.mutable_attributes: - self.mutable_attributes.remove(key) - for cls in self.class_.__subclasses__(): - manager = self._subclass_manager(cls) - manager.uninstrument_attribute(key, True) - def unregister(self): - """remove all instrumentation established by this ClassManager.""" - - self._uninstrument_init() + child_state, child_dict = instance_state(child), \ + instance_dict(child) + 
child_impl = child_state.manager[key].impl - self.mapper = self.events = None - self.info.clear() - - for key in list(self): - if key in self.local_attrs: - self.uninstrument_attribute(key) + if initiator.parent_token is not parent_token and \ + initiator.parent_token is not child_impl.parent_token: + _acceptable_key_err(state, initiator, child_impl) + elif initiator.impl is not child_impl or \ + initiator.op not in (OP_APPEND, OP_REPLACE): + child_impl.append( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH) + return child - def install_descriptor(self, key, inst): - if key in (self.STATE_ATTR, self.MANAGER_ATTR): - raise KeyError("%r: requested attribute name conflicts with " - "instrumentation attribute of the same name." % key) - setattr(self.class_, key, inst) + def emit_backref_from_collection_remove_event(state, child, initiator): + if child is not None: + child_state, child_dict = instance_state(child),\ + instance_dict(child) + child_impl = child_state.manager[key].impl + if initiator.impl is not child_impl or \ + initiator.op not in (OP_REMOVE, OP_REPLACE): + child_impl.pop( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH) - def uninstall_descriptor(self, key): - delattr(self.class_, key) + if uselist: + event.listen(attribute, "append", + emit_backref_from_collection_append_event, + retval=True, raw=True) + else: + event.listen(attribute, "set", + emit_backref_from_scalar_set_event, + retval=True, raw=True) + # TODO: need coverage in test/orm/ of remove event + event.listen(attribute, "remove", + emit_backref_from_collection_remove_event, + retval=True, raw=True) - def install_member(self, key, implementation): - if key in (self.STATE_ATTR, self.MANAGER_ATTR): - raise KeyError("%r: requested attribute name conflicts with " - "instrumentation attribute of the same name." % key) - self.originals.setdefault(key, getattr(self.class_, key, None)) - setattr(self.class_, key, implementation) +_NO_HISTORY = util.symbol('NO_HISTORY') +_NO_STATE_SYMBOLS = frozenset([ + id(PASSIVE_NO_RESULT), + id(NO_VALUE), + id(NEVER_SET)]) - def uninstall_member(self, key): - original = self.originals.pop(key, None) - if original is not None: - setattr(self.class_, key, original) +History = util.namedtuple("History", [ + "added", "unchanged", "deleted" +]) - def instrument_collection_class(self, key, collection_class): - return collections.prepare_instrumentation(collection_class) - def initialize_collection(self, key, state, factory): - user_data = factory() - adapter = collections.CollectionAdapter( - self.get_impl(key), state, user_data) - return adapter, user_data +class History(History): + """A 3-tuple of added, unchanged and deleted values, + representing the changes which have occurred on an instrumented + attribute. 
- def is_instrumented(self, key, search=False): - if search: - return key in self - else: - return key in self.local_attrs + The easiest way to get a :class:`.History` object for a particular + attribute on an object is to use the :func:`.inspect` function:: - def get_impl(self, key): - return self[key].impl + from sqlalchemy import inspect - @property - def attributes(self): - return self.itervalues() + hist = inspect(myobject).attrs.myattribute.history - ## InstanceState management + Each tuple member is an iterable sequence: - def new_instance(self, state=None): - instance = self.class_.__new__(self.class_) - setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance)) - return instance + * ``added`` - the collection of items added to the attribute (the first + tuple element). - def setup_instance(self, instance, state=None): - setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance)) - - def teardown_instance(self, instance): - delattr(instance, self.STATE_ATTR) - - def _new_state_if_none(self, instance): - """Install a default InstanceState if none is present. + * ``unchanged`` - the collection of items that have not changed on the + attribute (the second tuple element). - A private convenience method used by the __init__ decorator. - - """ - if hasattr(instance, self.STATE_ATTR): - return False - else: - state = self._create_instance_state(instance) - setattr(instance, self.STATE_ATTR, state) - return state - - def state_getter(self): - """Return a (instance) -> InstanceState callable. - - "state getter" callables should raise either KeyError or - AttributeError if no InstanceState could be found for the - instance. - """ - - return attrgetter(self.STATE_ATTR) - - def dict_getter(self): - return attrgetter('__dict__') - - def has_state(self, instance): - return hasattr(instance, self.STATE_ATTR) - - def has_parent(self, state, key, optimistic=False): - """TODO""" - return self.get_impl(key).hasparent(state, optimistic=optimistic) - - def __nonzero__(self): - """All ClassManagers are non-zero regardless of attribute state.""" - return True - - def __repr__(self): - return '<%s of %r at %x>' % ( - self.__class__.__name__, self.class_, id(self)) - -class _ClassInstrumentationAdapter(ClassManager): - """Adapts a user-defined InstrumentationManager to a ClassManager.""" - - def __init__(self, class_, override, **kw): - self._adapted = override - self._get_state = self._adapted.state_getter(class_) - self._get_dict = self._adapted.dict_getter(class_) - - ClassManager.__init__(self, class_, **kw) - - def manage(self): - self._adapted.manage(self.class_, self) - - def dispose(self): - self._adapted.dispose(self.class_) - - def manager_getter(self): - return self._adapted.manager_getter(self.class_) - - def instrument_attribute(self, key, inst, propagated=False): - ClassManager.instrument_attribute(self, key, inst, propagated) - if not propagated: - self._adapted.instrument_attribute(self.class_, key, inst) - - def post_configure_attribute(self, key): - self._adapted.post_configure_attribute(self.class_, key, self[key]) - - def install_descriptor(self, key, inst): - self._adapted.install_descriptor(self.class_, key, inst) - - def uninstall_descriptor(self, key): - self._adapted.uninstall_descriptor(self.class_, key) - - def install_member(self, key, implementation): - self._adapted.install_member(self.class_, key, implementation) - - def uninstall_member(self, key): - self._adapted.uninstall_member(self.class_, key) - - def 
instrument_collection_class(self, key, collection_class): - return self._adapted.instrument_collection_class( - self.class_, key, collection_class) - - def initialize_collection(self, key, state, factory): - delegate = getattr(self._adapted, 'initialize_collection', None) - if delegate: - return delegate(key, state, factory) - else: - return ClassManager.initialize_collection(self, key, state, factory) - - def new_instance(self, state=None): - instance = self.class_.__new__(self.class_) - self.setup_instance(instance, state) - return instance - - def _new_state_if_none(self, instance): - """Install a default InstanceState if none is present. - - A private convenience method used by the __init__ decorator. - """ - if self.has_state(instance): - return False - else: - return self.setup_instance(instance) - - def setup_instance(self, instance, state=None): - self._adapted.initialize_instance_dict(self.class_, instance) - - if state is None: - state = self._create_instance_state(instance) - - # the given instance is assumed to have no state - self._adapted.install_state(self.class_, instance, state) - return state - - def teardown_instance(self, instance): - self._adapted.remove_state(self.class_, instance) - - def has_state(self, instance): - try: - state = self._get_state(instance) - except exc.NO_STATE: - return False - else: - return True - - def state_getter(self): - return self._get_state - - def dict_getter(self): - return self._get_dict - -class History(tuple): - """A 3-tuple of added, unchanged and deleted values. - - Each tuple member is an iterable sequence. + * ``deleted`` - the collection of items that have been removed from the + attribute (the third tuple element). """ - __slots__ = () - - added = property(itemgetter(0)) - unchanged = property(itemgetter(1)) - deleted = property(itemgetter(2)) - - def __new__(cls, added, unchanged, deleted): - return tuple.__new__(cls, (added, unchanged, deleted)) - - def __nonzero__(self): + def __bool__(self): return self != HISTORY_BLANK - + __nonzero__ = __bool__ + + def empty(self): + """Return True if this :class:`.History` has no changes + and no existing, unchanged state. 
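+
+        A sketch of the expected behavior::
+
+            >>> History(added=(), unchanged=(), deleted=()).empty()
+            True
+            >>> History(added=['x'], unchanged=(), deleted=()).empty()
+            False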
+ + """ + + return not bool( + (self.added or self.deleted) + or self.unchanged + ) + def sum(self): + """Return a collection of added + unchanged + deleted.""" + return (self.added or []) +\ - (self.unchanged or []) +\ - (self.deleted or []) - + (self.unchanged or []) +\ + (self.deleted or []) + def non_deleted(self): + """Return a collection of added + unchanged.""" + return (self.added or []) +\ - (self.unchanged or []) - + (self.unchanged or []) + def non_added(self): + """Return a collection of unchanged + deleted.""" + return (self.unchanged or []) +\ - (self.deleted or []) - + (self.deleted or []) + def has_changes(self): + """Return True if this :class:`.History` has changes.""" + return bool(self.added or self.deleted) - + def as_state(self): return History( - [(c is not None and c is not PASSIVE_NO_RESULT) + [(c is not None) and instance_state(c) or None for c in self.added], - [(c is not None and c is not PASSIVE_NO_RESULT) + [(c is not None) and instance_state(c) or None for c in self.unchanged], - [(c is not None and c is not PASSIVE_NO_RESULT) + [(c is not None) and instance_state(c) or None for c in self.deleted], - ) - + ) + @classmethod - def from_attribute(cls, attribute, state, current): - original = state.committed_state.get(attribute.key, NEVER_SET) + def from_scalar_attribute(cls, attribute, state, current): + original = state.committed_state.get(attribute.key, _NO_HISTORY) - if hasattr(attribute, 'get_collection'): - current = attribute.get_collection(state, state.dict, current) - if original is NO_VALUE: - return cls(list(current), (), ()) - elif original is NEVER_SET: - return cls((), list(current), ()) + if original is _NO_HISTORY: + if current is NEVER_SET: + return cls((), (), ()) else: - current_set = util.IdentitySet(current) - original_set = util.IdentitySet(original) - - # ensure duplicates are maintained - return cls( - [x for x in current if x not in original_set], - [x for x in current if x in original_set], - [x for x in original if x not in current_set] - ) - else: - if current is NO_VALUE: - if (original is not None and - original is not NEVER_SET and - original is not NO_VALUE): - deleted = [original] - else: - deleted = () - return cls((), (), deleted) - elif original is NO_VALUE: - return cls([current], (), ()) - elif (original is NEVER_SET or - attribute.is_equal(current, original) is True): - # dont let ClauseElement expressions here trip things up return cls((), [current], ()) + # don't let ClauseElement expressions here trip things up + elif attribute.is_equal(current, original) is True: + return cls((), [current], ()) + else: + # current convention on native scalars is to not + # include information + # about missing previous value in "deleted", but + # we do include None, which helps in some primary + # key situations + if id(original) in _NO_STATE_SYMBOLS: + deleted = () + else: + deleted = [original] + if current is NEVER_SET: + return cls((), (), deleted) else: - if original is not None: - deleted = [original] - else: - deleted = () return cls([current], (), deleted) + @classmethod + def from_object_attribute(cls, attribute, state, current): + original = state.committed_state.get(attribute.key, _NO_HISTORY) + + if original is _NO_HISTORY: + if current is NO_VALUE or current is NEVER_SET: + return cls((), (), ()) + else: + return cls((), [current], ()) + elif current is original: + return cls((), [current], ()) + else: + # current convention on related objects is to not + # include information + # about missing previous value in "deleted", 
and + # to also not include None - the dependency.py rules + # ignore the None in any case. + if id(original) in _NO_STATE_SYMBOLS or original is None: + deleted = () + else: + deleted = [original] + if current is NO_VALUE or current is NEVER_SET: + return cls((), (), deleted) + else: + return cls([current], (), deleted) + + @classmethod + def from_collection(cls, attribute, state, current): + original = state.committed_state.get(attribute.key, _NO_HISTORY) + + if current is NO_VALUE or current is NEVER_SET: + return cls((), (), ()) + + current = getattr(current, '_sa_adapter') + if original in (NO_VALUE, NEVER_SET): + return cls(list(current), (), ()) + elif original is _NO_HISTORY: + return cls((), list(current), ()) + else: + + current_states = [((c is not None) and instance_state(c) + or None, c) + for c in current + ] + original_states = [((c is not None) and instance_state(c) + or None, c) + for c in original + ] + + current_set = dict(current_states) + original_set = dict(original_states) + + return cls( + [o for s, o in current_states if s not in original_set], + [o for s, o in current_states if s in original_set], + [o for s, o in original_states if s not in current_set] + ) + HISTORY_BLANK = History(None, None, None) -def get_history(obj, key, **kwargs): - """Return a History record for the given object and attribute key. - - obj is an instrumented object instance. An InstanceState - is accepted directly for backwards compatibility but - this usage is deprecated. - - """ - return get_state_history(instance_state(obj), key, **kwargs) -def get_state_history(state, key, **kwargs): - return state.get_history(key, **kwargs) +def get_history(obj, key, passive=PASSIVE_OFF): + """Return a :class:`.History` record for the given object + and attribute key. + + :param obj: an object whose class is instrumented by the + attributes package. + + :param key: string attribute name. + + :param passive: indicates loading behavior for the attribute + if the value is not already present. This is a + bitflag attribute, which defaults to the symbol + :attr:`.PASSIVE_OFF` indicating all necessary SQL + should be emitted. + + """ + if passive is True: + util.warn_deprecated("Passing True for 'passive' is deprecated. " + "Use attributes.PASSIVE_NO_INITIALIZE") + passive = PASSIVE_NO_INITIALIZE + elif passive is False: + util.warn_deprecated("Passing False for 'passive' is " + "deprecated. Use attributes.PASSIVE_OFF") + passive = PASSIVE_OFF + + return get_state_history(instance_state(obj), key, passive) + + +def get_state_history(state, key, passive=PASSIVE_OFF): + return state.get_history(key, passive) + def has_parent(cls, obj, key, optimistic=False): """TODO""" @@ -1349,37 +1450,22 @@ def has_parent(cls, obj, key, optimistic=False): state = instance_state(obj) return manager.has_parent(state, key, optimistic) -def register_class(class_, **kw): - """Register class instrumentation. - - Returns the existing or newly created class manager. 
- """ - - manager = manager_of_class(class_) - if manager is None: - manager = _create_manager_for_cls(class_, **kw) - return manager - -def unregister_class(class_): - """Unregister class instrumentation.""" - - instrumentation_registry.unregister(class_) def register_attribute(class_, key, **kw): - - proxy_property = kw.pop('proxy_property', None) - comparator = kw.pop('comparator', None) parententity = kw.pop('parententity', None) - register_descriptor(class_, key, proxy_property, comparator, parententity) - if not proxy_property: - register_attribute_impl(class_, key, **kw) - -def register_attribute_impl(class_, key, - uselist=False, callable_=None, - useobject=False, mutable_scalars=False, - impl_class=None, **kw): - + doc = kw.pop('doc', None) + desc = register_descriptor(class_, key, + comparator, parententity, doc=doc) + register_attribute_impl(class_, key, **kw) + return desc + + +def register_attribute_impl(class_, key, + uselist=False, callable_=None, + useobject=False, + impl_class=None, backref=None, **kw): + manager = manager_of_class(class_) if uselist: factory = kw.pop('typecallable', None) @@ -1388,69 +1474,80 @@ def register_attribute_impl(class_, key, else: typecallable = kw.pop('typecallable', None) + dispatch = manager[key].dispatch + if impl_class: - impl = impl_class(class_, key, typecallable, **kw) + impl = impl_class(class_, key, typecallable, dispatch, **kw) elif uselist: - impl = CollectionAttributeImpl(class_, key, callable_, + impl = CollectionAttributeImpl(class_, key, callable_, dispatch, typecallable=typecallable, **kw) elif useobject: - impl = ScalarObjectAttributeImpl(class_, key, callable_, **kw) - elif mutable_scalars: - impl = MutableScalarAttributeImpl(class_, key, callable_, - class_manager=manager, **kw) + impl = ScalarObjectAttributeImpl(class_, key, callable_, + dispatch, **kw) else: - impl = ScalarAttributeImpl(class_, key, callable_, **kw) + impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw) manager[key].impl = impl - + + if backref: + backref_listeners(manager[key], backref, uselist) + manager.post_configure_attribute(key) - -def register_descriptor(class_, key, proxy_property=None, comparator=None, parententity=None, property_=None): + return manager[key] + + +def register_descriptor(class_, key, comparator=None, + parententity=None, doc=None): manager = manager_of_class(class_) - if proxy_property: - proxy_type = proxied_attribute_factory(proxy_property) - descriptor = proxy_type(key, proxy_property, comparator, parententity) - else: - descriptor = InstrumentedAttribute(key, comparator=comparator, parententity=parententity) + descriptor = InstrumentedAttribute(class_, key, comparator=comparator, + parententity=parententity) + + descriptor.__doc__ = doc manager.instrument_attribute(key, descriptor) + return descriptor + def unregister_attribute(class_, key): manager_of_class(class_).uninstrument_attribute(key) + def init_collection(obj, key): """Initialize a collection attribute and return the collection adapter. - + This function is used to provide direct access to collection internals for a previously unloaded attribute. e.g.:: - + collection_adapter = init_collection(someobject, 'elements') for elem in values: collection_adapter.append_without_event(elem) - - For an easier way to do the above, see :func:`~sqlalchemy.orm.attributes.set_committed_value`. - + + For an easier way to do the above, see + :func:`~sqlalchemy.orm.attributes.set_committed_value`. + obj is an instrumented object instance. 
An InstanceState - is accepted directly for backwards compatibility but + is accepted directly for backwards compatibility but this usage is deprecated. - + """ state = instance_state(obj) dict_ = state.dict return init_state_collection(state, dict_, key) - + + def init_state_collection(state, dict_, key): """Initialize a collection attribute and return the collection adapter.""" - - attr = state.get_impl(key) + + attr = state.manager[key].impl user_data = attr.initialize(state, dict_) return attr.get_collection(state, dict_, user_data) + def set_committed_value(instance, key, value): """Set the value of an attribute with no history events. - - Cancels any previous history present. The value should be + + Cancels any previous history present. The value should be a scalar value for scalar-holding attributes, or an iterable for any collection-holding attribute. @@ -1460,23 +1557,25 @@ def set_committed_value(instance, key, value): which has loaded additional attributes or collections through separate queries, which can then be attached to an instance as though it were part of its original loaded state. - + """ state, dict_ = instance_state(instance), instance_dict(instance) - state.get_impl(key).set_committed_value(state, dict_, value) - + state.manager[key].impl.set_committed_value(state, dict_, value) + + def set_attribute(instance, key, value): """Set the value of an attribute, firing history events. - + This function may be used regardless of instrumentation applied directly to the class, i.e. no descriptors are required. Custom attribute management schemes will need to make usage of this method to establish attribute state as understood by SQLAlchemy. - + """ state, dict_ = instance_state(instance), instance_dict(instance) - state.get_impl(key).set(state, dict_, value, None) + state.manager[key].impl.set(state, dict_, value, None) + def get_attribute(instance, key): """Get the value of an attribute, firing any callables required. @@ -1486,10 +1585,11 @@ def get_attribute(instance, key): Custom attribute management schemes will need to make usage of this method to make usage of attribute state as understood by SQLAlchemy. - + """ state, dict_ = instance_state(instance), instance_dict(instance) - return state.get_impl(key).get(state, dict_) + return state.manager[key].impl.get(state, dict_) + def del_attribute(instance, key): """Delete the value of an attribute, firing history events. @@ -1499,210 +1599,19 @@ def del_attribute(instance, key): Custom attribute management schemes will need to make usage of this method to establish attribute state as understood by SQLAlchemy. - + """ state, dict_ = instance_state(instance), instance_dict(instance) - state.get_impl(key).delete(state, dict_) + state.manager[key].impl.delete(state, dict_) + + +def flag_modified(instance, key): + """Mark an attribute on an instance as 'modified'. + + This sets the 'modified' flag on the instance and + establishes an unconditional change event for the given attribute. -def is_instrumented(instance, key): - """Return True if the given attribute on the given instance is instrumented - by the attributes package. - - This function may be used regardless of instrumentation - applied directly to the class, i.e. no descriptors are required. - """ - return manager_of_class(instance.__class__).is_instrumented(key, search=True) - -class InstrumentationRegistry(object): - """Private instrumentation registration singleton. 
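The setters above differ only in their event behavior; a sketch contrasting them, with User, Address and the session assumed to exist outside this diff:

    from sqlalchemy.orm.attributes import (
        get_attribute, set_attribute, set_committed_value)

    # fires history events, equivalent to "user.name = 'jack'"
    set_attribute(user, 'name', 'jack')
    assert get_attribute(user, 'name') == 'jack'

    # attaches a separately-loaded collection with no history events,
    # as though it had been part of the originally loaded state
    addresses = session.query(Address).filter_by(user_id=user.id).all()
    set_committed_value(user, 'addresses', addresses)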
- - All classes are routed through this registry - when first instrumented, however the InstrumentationRegistry - is not actually needed unless custom ClassManagers are in use. - - """ - - _manager_finders = weakref.WeakKeyDictionary() - _state_finders = util.WeakIdentityMapping() - _dict_finders = util.WeakIdentityMapping() - _extended = False - - def create_manager_for_cls(self, class_, **kw): - assert class_ is not None - assert manager_of_class(class_) is None - - for finder in instrumentation_finders: - factory = finder(class_) - if factory is not None: - break - else: - factory = ClassManager - - existing_factories = self._collect_management_factories_for(class_).\ - difference([factory]) - if existing_factories: - raise TypeError( - "multiple instrumentation implementations specified " - "in %s inheritance hierarchy: %r" % ( - class_.__name__, list(existing_factories))) - - manager = factory(class_) - if not isinstance(manager, ClassManager): - manager = _ClassInstrumentationAdapter(class_, manager) - - if factory != ClassManager and not self._extended: - # somebody invoked a custom ClassManager. - # reinstall global "getter" functions with the more - # expensive ones. - self._extended = True - _install_lookup_strategy(self) - - manager._configure_create_arguments(**kw) - - manager.factory = factory - self._manager_finders[class_] = manager.manager_getter() - self._state_finders[class_] = manager.state_getter() - self._dict_finders[class_] = manager.dict_getter() - return manager - - def _collect_management_factories_for(self, cls): - """Return a collection of factories in play or specified for a hierarchy. - - Traverses the entire inheritance graph of a cls and returns a collection - of instrumentation factories for those classes. Factories are extracted - from active ClassManagers, if available, otherwise - instrumentation_finders is consulted. 
- - """ - hierarchy = util.class_hierarchy(cls) - factories = set() - for member in hierarchy: - manager = manager_of_class(member) - if manager is not None: - factories.add(manager.factory) - else: - for finder in instrumentation_finders: - factory = finder(member) - if factory is not None: - break - else: - factory = None - factories.add(factory) - factories.discard(None) - return factories - - def manager_of_class(self, cls): - # this is only called when alternate instrumentation has been established - if cls is None: - return None - try: - finder = self._manager_finders[cls] - except KeyError: - return None - else: - return finder(cls) - - def state_of(self, instance): - # this is only called when alternate instrumentation has been established - if instance is None: - raise AttributeError("None has no persistent state.") - try: - return self._state_finders[instance.__class__](instance) - except KeyError: - raise AttributeError("%r is not instrumented" % instance.__class__) - - def dict_of(self, instance): - # this is only called when alternate instrumentation has been established - if instance is None: - raise AttributeError("None has no persistent state.") - try: - return self._dict_finders[instance.__class__](instance) - except KeyError: - raise AttributeError("%r is not instrumented" % instance.__class__) - - def unregister(self, class_): - if class_ in self._manager_finders: - manager = self.manager_of_class(class_) - manager.unregister() - manager.dispose() - del self._manager_finders[class_] - del self._state_finders[class_] - del self._dict_finders[class_] - if ClassManager.MANAGER_ATTR in class_.__dict__: - delattr(class_, ClassManager.MANAGER_ATTR) - -instrumentation_registry = InstrumentationRegistry() - -def _install_lookup_strategy(implementation): - """Replace global class/object management functions - with either faster or more comprehensive implementations, - based on whether or not extended class instrumentation - has been detected. - - This function is called only by InstrumentationRegistry() - and unit tests specific to this behavior. - - """ - global instance_state, instance_dict, manager_of_class - if implementation is util.symbol('native'): - instance_state = attrgetter(ClassManager.STATE_ATTR) - instance_dict = attrgetter("__dict__") - def manager_of_class(cls): - return cls.__dict__.get(ClassManager.MANAGER_ATTR, None) - else: - instance_state = instrumentation_registry.state_of - instance_dict = instrumentation_registry.dict_of - manager_of_class = instrumentation_registry.manager_of_class - -_create_manager_for_cls = instrumentation_registry.create_manager_for_cls - -# Install default "lookup" strategies. These are basically -# very fast attrgetters for key attributes. -# When a custom ClassManager is installed, more expensive per-class -# strategies are copied over these. -_install_lookup_strategy(util.symbol('native')) - -def find_native_user_instrumentation_hook(cls): - """Find user-specified instrumentation management for a class.""" - return getattr(cls, INSTRUMENTATION_MANAGER, None) -instrumentation_finders.append(find_native_user_instrumentation_hook) - -def _generate_init(class_, class_manager): - """Build an __init__ decorator that triggers ClassManager events.""" - - # TODO: we should use the ClassManager's notion of the - # original '__init__' method, once ClassManager is fixed - # to always reference that. 
- original__init__ = class_.__init__ - assert original__init__ - - # Go through some effort here and don't change the user's __init__ - # calling signature. - # FIXME: need to juggle local names to avoid constructor argument - # clashes. - func_body = """\ -def __init__(%(apply_pos)s): - new_state = class_manager._new_state_if_none(%(self_arg)s) - if new_state: - return new_state.initialize_instance(%(apply_kw)s) - else: - return original__init__(%(apply_kw)s) -""" - func_vars = util.format_argspec_init(original__init__, grouped=False) - func_text = func_body % func_vars - - # Py3K - #func_defaults = getattr(original__init__, '__defaults__', None) - # Py2K - func = getattr(original__init__, 'im_func', original__init__) - func_defaults = getattr(func, 'func_defaults', None) - # end Py2K - - env = locals().copy() - exec func_text in env - __init__ = env['__init__'] - __init__.__doc__ = original__init__.__doc__ - if func_defaults: - __init__.func_defaults = func_defaults - return __init__ + state, dict_ = instance_state(instance), instance_dict(instance) + impl = state.manager[key].impl + state._modified_event(dict_, impl, NO_VALUE, force=True) diff --git a/sqlalchemy/orm/collections.py b/sqlalchemy/orm/collections.py index 616f251..2bb53e6 100644 --- a/sqlalchemy/orm/collections.py +++ b/sqlalchemy/orm/collections.py @@ -1,3 +1,10 @@ +# orm/collections.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + """Support for collections of mapped entities. The collections package supplies the machinery used to inform the ORM of @@ -6,7 +13,7 @@ used, allowing arbitrary types (including built-ins) to be used as entity collections without requiring inheritance from a base class. Instrumentation decoration relays membership change events to the -``InstrumentedCollectionAttribute`` that is currently managing the collection. +:class:`.CollectionAttributeImpl` that is currently managing the collection. The decorators observe function call arguments and return values, tracking entities entering or leaving the collection. Two decorator approaches are provided. One is a bundle of generic decorators that map function arguments @@ -91,21 +98,20 @@ instrumentation may be the answer. Within your method, ``collection_adapter(self)`` will retrieve an object that you can use for explicit control over triggering append and remove events. -The owning object and InstrumentedCollectionAttribute are also reachable +The owning object and :class:`.CollectionAttributeImpl` are also reachable through the adapter, allowing for some very sophisticated behavior. """ -import copy import inspect import operator -import sys import weakref -import sqlalchemy.exceptions as sa_exc -from sqlalchemy.sql import expression -from sqlalchemy import schema, util +from ..sql import expression +from .. import util, exc as sa_exc +from . import base +from sqlalchemy.util.compat import inspect_getargspec __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', @@ -114,11 +120,116 @@ __all__ = ['collection', 'collection_adapter', __instrumentation_mutex = util.threading.Lock() +class _PlainColumnGetter(object): + """Plain column getter, stores collection of Column objects + directly. + + Serializes to a :class:`._SerializableColumnGetterV2` + which has more expensive __call__() performance + and some rare caveats. 
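Looking back at flag_modified(), which closes out the attributes.py portion of this diff: a minimal usage sketch, assuming a hypothetical mapped Document class whose 'data' attribute holds a mutable value the ORM cannot track in place:

    from sqlalchemy.orm.attributes import flag_modified

    doc = session.query(Document).get(1)
    doc.data['status'] = 'done'   # in-place mutation, no event fired
    flag_modified(doc, 'data')    # unconditional change event
    session.commit()              # 'data' is now part of the UPDATE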
+ + """ + + def __init__(self, cols): + self.cols = cols + self.composite = len(cols) > 1 + + def __reduce__(self): + return _SerializableColumnGetterV2._reduce_from_cols(self.cols) + + def _cols(self, mapper): + return self.cols + + def __call__(self, value): + state = base.instance_state(value) + m = base._state_mapper(state) + + key = [ + m._get_state_attr_by_column(state, state.dict, col) + for col in self._cols(m) + ] + + if self.composite: + return tuple(key) + else: + return key[0] + + +class _SerializableColumnGetter(object): + """Column-based getter used in version 0.7.6 only. + + Remains here for pickle compatibility with 0.7.6. + + """ + + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return _SerializableColumnGetter, (self.colkeys,) + + def __call__(self, value): + state = base.instance_state(value) + m = base._state_mapper(state) + key = [m._get_state_attr_by_column( + state, state.dict, + m.mapped_table.columns[k]) + for k in self.colkeys] + if self.composite: + return tuple(key) + else: + return key[0] + + +class _SerializableColumnGetterV2(_PlainColumnGetter): + """Updated serializable getter which deals with + multi-table mapped classes. + + Two extremely unusual cases are not supported. + Mappings which have tables across multiple metadata + objects, or which are mapped to non-Table selectables + linked across inheriting mappers may fail to function + here. + + """ + + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return self.__class__, (self.colkeys,) + + @classmethod + def _reduce_from_cols(cls, cols): + def _table_key(c): + if not isinstance(c.table, expression.TableClause): + return None + else: + return c.table.key + colkeys = [(c.key, _table_key(c)) for c in cols] + return _SerializableColumnGetterV2, (colkeys,) + + def _cols(self, mapper): + cols = [] + metadata = getattr(mapper.local_table, 'metadata', None) + for (ckey, tkey) in self.colkeys: + if tkey is None or \ + metadata is None or \ + tkey not in metadata: + cols.append(mapper.local_table.c[ckey]) + else: + cols.append(metadata.tables[tkey].c[ckey]) + return cols + + def column_mapped_collection(mapping_spec): """A dictionary-based collection type with column-based keying. - Returns a MappedCollection factory with a keying function generated - from mapping_spec, which may be a Column or a sequence of Columns. + Returns a :class:`.MappedCollection` factory with a keying function + generated from mapping_spec, which may be a Column or a sequence + of Columns. The key value must be immutable for the lifetime of the object. You can not, for example, map on foreign key values if those key values will @@ -126,29 +237,31 @@ def column_mapped_collection(mapping_spec): after a session flush. 
""" - from sqlalchemy.orm.util import _state_mapper - from sqlalchemy.orm.attributes import instance_state - - cols = [expression._no_literals(q) for q in util.to_list(mapping_spec)] - if len(cols) == 1: - def keyfunc(value): - state = instance_state(value) - m = _state_mapper(state) - return m._get_state_attr_by_column(state, cols[0]) - else: - mapping_spec = tuple(cols) - def keyfunc(value): - state = instance_state(value) - m = _state_mapper(state) - return tuple(m._get_state_attr_by_column(state, c) - for c in mapping_spec) + cols = [expression._only_column_elements(q, "mapping_spec") + for q in util.to_list(mapping_spec) + ] + keyfunc = _PlainColumnGetter(cols) return lambda: MappedCollection(keyfunc) + +class _SerializableAttrGetter(object): + def __init__(self, name): + self.name = name + self.getter = operator.attrgetter(name) + + def __call__(self, target): + return self.getter(target) + + def __reduce__(self): + return _SerializableAttrGetter, (self.name, ) + + def attribute_mapped_collection(attr_name): """A dictionary-based collection type with attribute-based keying. - Returns a MappedCollection factory with a keying based on the - 'attr_name' attribute of entities in the collection. + Returns a :class:`.MappedCollection` factory with a keying based on the + 'attr_name' attribute of entities in the collection, where ``attr_name`` + is the string name of the attribute. The key value must be immutable for the lifetime of the object. You can not, for example, map on foreign key values if those key values will @@ -156,14 +269,16 @@ def attribute_mapped_collection(attr_name): after a session flush. """ - return lambda: MappedCollection(operator.attrgetter(attr_name)) + getter = _SerializableAttrGetter(attr_name) + return lambda: MappedCollection(getter) def mapped_collection(keyfunc): """A dictionary-based collection type with arbitrary keying. - Returns a MappedCollection factory with a keying function generated - from keyfunc, a callable that takes an entity and returns a key value. + Returns a :class:`.MappedCollection` factory with a keying function + generated from keyfunc, a callable that takes an entity and returns a + key value. The key value must be immutable for the lifetime of the object. You can not, for example, map on foreign key values if those key values will @@ -173,13 +288,14 @@ def mapped_collection(keyfunc): """ return lambda: MappedCollection(keyfunc) + class collection(object): """Decorators for entity collection classes. The decorators fall into two groups: annotations and interception recipes. - The annotating decorators (appender, remover, iterator, - internally_instrumented, on_link) indicate the method's purpose and take no + The annotating decorators (appender, remover, iterator, linker, converter, + internally_instrumented) indicate the method's purpose and take no arguments. They are not written with parens:: @collection.appender @@ -188,16 +304,12 @@ class collection(object): The recipe decorators all require parens, even those that take no arguments:: - @collection.adds('entity'): + @collection.adds('entity') def insert(self, position, entity): ... @collection.removes_return() def popitem(self): ... - Decorators can be specified in long-hand for Python 2.3, or with - the class-level dict attribute '__instrumentation__'- see the source - for details. - """ # Bundled as a class solely for ease of use: packaging, doc strings, # importability. @@ -243,7 +355,7 @@ class collection(object): promulgation to collection events. 
""" - setattr(fn, '_sa_instrument_role', 'appender') + fn._sa_instrument_role = 'appender' return fn @staticmethod @@ -252,7 +364,7 @@ class collection(object): The remover method is called with one positional argument: the value to remove. The method will be automatically decorated with - 'removes_return()' if not already decorated:: + :meth:`removes_return` if not already decorated:: @collection.remover def zap(self, entity): ... @@ -270,7 +382,7 @@ class collection(object): promulgation to collection events. """ - setattr(fn, '_sa_instrument_role', 'remover') + fn._sa_instrument_role = 'remover' return fn @staticmethod @@ -284,17 +396,18 @@ class collection(object): def __iter__(self): ... """ - setattr(fn, '_sa_instrument_role', 'iterator') + fn._sa_instrument_role = 'iterator' return fn @staticmethod def internally_instrumented(fn): """Tag the method as instrumented. - This tag will prevent any decoration from being applied to the method. - Use this if you are orchestrating your own calls to collection_adapter - in one of the basic SQLAlchemy interface methods, or to prevent - an automatic ABC method decoration from wrapping your implementation:: + This tag will prevent any decoration from being applied to the + method. Use this if you are orchestrating your own calls to + :func:`.collection_adapter` in one of the basic SQLAlchemy + interface methods, or to prevent an automatic ABC method + decoration from wrapping your implementation:: # normally an 'extend' method on a list-like class would be # automatically intercepted and re-implemented in terms of @@ -304,12 +417,12 @@ class collection(object): def extend(self, items): ... """ - setattr(fn, '_sa_instrumented', True) + fn._sa_instrumented = True return fn @staticmethod - def on_link(fn): - """Tag the method as a the "linked to attribute" event handler. + def linker(fn): + """Tag the method as a "linked to attribute" event handler. This optional event handler will be called when the collection class is linked to or unlinked from the InstrumentedAttribute. It is @@ -317,10 +430,17 @@ class collection(object): the instance. A single argument is passed: the collection adapter that has been linked, or None if unlinking. + .. deprecated:: 1.0.0 - the :meth:`.collection.linker` handler + is superseded by the :meth:`.AttributeEvents.init_collection` + and :meth:`.AttributeEvents.dispose_collection` handlers. + """ - setattr(fn, '_sa_instrument_role', 'on_link') + fn._sa_instrument_role = 'linker' return fn + link = linker + """deprecated; synonym for :meth:`.collection.linker`.""" + @staticmethod def converter(fn): """Tag the method as the collection converter. @@ -333,12 +453,12 @@ class collection(object): The converter method will receive the object being assigned and should return an iterable of values suitable for use by the ``appender`` method. A converter must not assign values or mutate the collection, - it's sole job is to adapt the value the user provides into an iterable + its sole job is to adapt the value the user provides into an iterable of values for the ORM's use. The default converter implementation will use duck-typing to do the conversion. A dict-like collection will be convert into an iterable - of dictionary values, and other types will simply be iterated. + of dictionary values, and other types will simply be iterated:: @collection.converter def convert(self, other): ... @@ -351,7 +471,7 @@ class collection(object): validation on the values about to be assigned. 
""" - setattr(fn, '_sa_instrument_role', 'converter') + fn._sa_instrument_role = 'converter' return fn @staticmethod @@ -371,7 +491,7 @@ class collection(object): """ def decorator(fn): - setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) + fn._sa_instrument_before = ('fire_append_event', arg) return fn return decorator @@ -391,8 +511,8 @@ class collection(object): """ def decorator(fn): - setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) - setattr(fn, '_sa_instrument_after', 'fire_remove_event') + fn._sa_instrument_before = ('fire_append_event', arg) + fn._sa_instrument_after = 'fire_remove_event' return fn return decorator @@ -413,7 +533,7 @@ class collection(object): """ def decorator(fn): - setattr(fn, '_sa_instrument_before', ('fire_remove_event', arg)) + fn._sa_instrument_before = ('fire_remove_event', arg) return fn return decorator @@ -421,9 +541,9 @@ class collection(object): def removes_return(): """Mark the method as removing an entity in the collection. - Adds "remove from collection" handling to the method. The return value - of the method, if any, is considered the value to remove. The method - arguments are not inspected:: + Adds "remove from collection" handling to the method. The return + value of the method, if any, is considered the value to remove. The + method arguments are not inspected:: @collection.removes_return() def pop(self): ... @@ -433,30 +553,13 @@ class collection(object): """ def decorator(fn): - setattr(fn, '_sa_instrument_after', 'fire_remove_event') + fn._sa_instrument_after = 'fire_remove_event' return fn return decorator -# public instrumentation interface for 'internally instrumented' -# implementations -def collection_adapter(collection): - """Fetch the CollectionAdapter for a collection.""" - return getattr(collection, '_sa_adapter', None) - -def collection_iter(collection): - """Iterate over an object supporting the @iterator or __iter__ protocols. - - If the collection is an ORM collection, it need not be attached to an - object to be iterable. - - """ - try: - return getattr(collection, '_sa_iterator', - getattr(collection, '__iter__'))() - except AttributeError: - raise TypeError("'%s' object is not iterable" % - type(collection).__name__) +collection_adapter = operator.attrgetter('_sa_adapter') +"""Fetch the :class:`.CollectionAdapter` for a collection.""" class CollectionAdapter(object): @@ -466,126 +569,115 @@ class CollectionAdapter(object): to the underlying Python collection, and emits add/remove events for entities entering or leaving the collection. - The ORM uses an CollectionAdapter exclusively for interaction with + The ORM uses :class:`.CollectionAdapter` exclusively for interaction with entity collections. 
+ """ + + __slots__ = ( + 'attr', '_key', '_data', 'owner_state', '_converter', 'invalidated') + def __init__(self, attr, owner_state, data): self.attr = attr - # TODO: figure out what this being a weakref buys us + self._key = attr.key self._data = weakref.ref(data) self.owner_state = owner_state - self.link_to_self(data) + data._sa_adapter = self + self._converter = data._sa_converter + self.invalidated = False - data = property(lambda s: s._data(), - doc="The entity collection being adapted.") + def _warn_invalidated(self): + util.warn("This collection has been invalidated.") - def link_to_self(self, data): - """Link a collection to this adapter, and fire a link event.""" - setattr(data, '_sa_adapter', self) - if hasattr(data, '_sa_on_link'): - getattr(data, '_sa_on_link')(self) + @property + def data(self): + "The entity collection being adapted." + return self._data() - def unlink(self, data): - """Unlink a collection from any adapter, and fire a link event.""" - setattr(data, '_sa_adapter', None) - if hasattr(data, '_sa_on_link'): - getattr(data, '_sa_on_link')(None) + @property + def _referenced_by_owner(self): + """return True if the owner state still refers to this collection. - def adapt_like_to_iterable(self, obj): - """Converts collection-compatible objects to an iterable of values. - - Can be passed any type of object, and if the underlying collection - determines that it can be adapted into a stream of values it can - use, returns an iterable of values suitable for append()ing. - - This method may raise TypeError or any other suitable exception - if adaptation fails. - - If a converter implementation is not supplied on the collection, - a default duck-typing-based implementation is used. + This will return False within a bulk replace operation, + where this collection is the one being replaced. """ - converter = getattr(self._data(), '_sa_converter', None) - if converter is not None: - return converter(obj) + return self.owner_state.dict[self._key] is self._data() - setting_type = util.duck_type_collection(obj) - receiving_type = util.duck_type_collection(self._data()) - - if obj is None or setting_type != receiving_type: - given = obj is None and 'None' or obj.__class__.__name__ - if receiving_type is None: - wanted = self._data().__class__.__name__ - else: - wanted = receiving_type.__name__ - - raise TypeError( - "Incompatible collection type: %s is not %s-like" % ( - given, wanted)) - - # If the object is an adapted collection, return the (iterable) - # adapter. 
- if getattr(obj, '_sa_adapter', None) is not None: - return getattr(obj, '_sa_adapter') - elif setting_type == dict: - # Py3K - #return obj.values() - # Py2K - return getattr(obj, 'itervalues', getattr(obj, 'values'))() - # end Py2K - else: - return iter(obj) + def bulk_appender(self): + return self._data()._sa_appender def append_with_event(self, item, initiator=None): """Add an entity to the collection, firing mutation events.""" - getattr(self._data(), '_sa_appender')(item, _sa_initiator=initiator) + + self._data()._sa_appender(item, _sa_initiator=initiator) def append_without_event(self, item): """Add or restore an entity to the collection, firing no events.""" - getattr(self._data(), '_sa_appender')(item, _sa_initiator=False) + self._data()._sa_appender(item, _sa_initiator=False) + + def append_multiple_without_event(self, items): + """Add or restore an entity to the collection, firing no events.""" + appender = self._data()._sa_appender + for item in items: + appender(item, _sa_initiator=False) + + def bulk_remover(self): + return self._data()._sa_remover def remove_with_event(self, item, initiator=None): """Remove an entity from the collection, firing mutation events.""" - getattr(self._data(), '_sa_remover')(item, _sa_initiator=initiator) + self._data()._sa_remover(item, _sa_initiator=initiator) def remove_without_event(self, item): """Remove an entity from the collection, firing no events.""" - getattr(self._data(), '_sa_remover')(item, _sa_initiator=False) + self._data()._sa_remover(item, _sa_initiator=False) def clear_with_event(self, initiator=None): """Empty the collection, firing a mutation event for each entity.""" + + remover = self._data()._sa_remover for item in list(self): - self.remove_with_event(item, initiator) + remover(item, _sa_initiator=initiator) def clear_without_event(self): """Empty the collection, firing no events.""" + + remover = self._data()._sa_remover for item in list(self): - self.remove_without_event(item) + remover(item, _sa_initiator=False) def __iter__(self): """Iterate over entities in the collection.""" - - # Py3K requires iter() here - return iter(getattr(self._data(), '_sa_iterator')()) + + return iter(self._data()._sa_iterator()) def __len__(self): """Count entities in the collection.""" - return len(list(getattr(self._data(), '_sa_iterator')())) + return len(list(self._data()._sa_iterator())) - def __nonzero__(self): + def __bool__(self): return True + __nonzero__ = __bool__ + def fire_append_event(self, item, initiator=None): """Notify that a entity has entered the collection. - Initiator is the InstrumentedAttribute that initiated the membership - mutation, and should be left as None unless you are passing along - an initiator value from a chained operation. + Initiator is a token owned by the InstrumentedAttribute that + initiated the membership mutation, and should be left as None + unless you are passing along an initiator value from a chained + operation. """ - if initiator is not False and item is not None: - return self.attr.fire_append_event(self.owner_state, self.owner_state.dict, item, initiator) + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + return self.attr.fire_append_event( + self.owner_state, + self.owner_state.dict, + item, initiator) else: return item @@ -597,8 +689,13 @@ class CollectionAdapter(object): an initiator value from a chained operation. 
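For internally instrumented methods, the adapter machinery above is driven by hand; a hedged sketch (replace_all is an invented method, and it assumes the collection is already linked to an adapter, i.e. in use on a mapped attribute):

    from sqlalchemy.orm.collections import collection, collection_adapter

    class Inventory(list):
        @collection.internally_instrumented
        def replace_all(self, items):
            adapter = collection_adapter(self)
            for old in list(self):
                adapter.fire_remove_event(old)
            list.__delitem__(self, slice(None))   # raw clear, no events
            for new in items:
                adapter.fire_append_event(new)
                list.append(self, new)            # raw append, no events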
""" - if initiator is not False and item is not None: - self.attr.fire_remove_event(self.owner_state, self.owner_state.dict, item, initiator) + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + self.attr.fire_remove_event( + self.owner_state, + self.owner_state.dict, + item, initiator) def fire_pre_remove_event(self, initiator=None): """Notify that an entity is about to be removed from the collection. @@ -607,17 +704,28 @@ class CollectionAdapter(object): fire_remove_event(). """ - self.attr.fire_pre_remove_event(self.owner_state, self.owner_state.dict, initiator=initiator) + if self.invalidated: + self._warn_invalidated() + self.attr.fire_pre_remove_event( + self.owner_state, + self.owner_state.dict, + initiator=initiator) def __getstate__(self): - return {'key': self.attr.key, + return {'key': self._key, 'owner_state': self.owner_state, - 'data': self.data} + 'owner_cls': self.owner_state.class_, + 'data': self.data, + 'invalidated': self.invalidated} def __setstate__(self, d): - self.attr = getattr(d['owner_state'].obj().__class__, d['key']).impl + self._key = d['key'] self.owner_state = d['owner_state'] self._data = weakref.ref(d['data']) + self._converter = d['data']._sa_converter + d['data']._sa_adapter = self + self.invalidated = d['invalidated'] + self.attr = getattr(d['owner_cls'], self._key).impl def bulk_replace(values, existing_adapter, new_adapter): @@ -628,34 +736,38 @@ def bulk_replace(values, existing_adapter, new_adapter): instances in ``existing_adapter`` not present in ``values`` will have remove events fired upon them. - values - An iterable of collection member instances + :param values: An iterable of collection member instances - existing_adapter - A CollectionAdapter of instances to be replaced + :param existing_adapter: A :class:`.CollectionAdapter` of + instances to be replaced - new_adapter - An empty CollectionAdapter to load with ``values`` + :param new_adapter: An empty :class:`.CollectionAdapter` + to load with ``values`` """ - if not isinstance(values, list): - values = list(values) + + assert isinstance(values, list) idset = util.IdentitySet - constants = idset(existing_adapter or ()).intersection(values or ()) + existing_idset = idset(existing_adapter or ()) + constants = existing_idset.intersection(values or ()) additions = idset(values or ()).difference(constants) - removals = idset(existing_adapter or ()).difference(constants) + removals = existing_idset.difference(constants) + + appender = new_adapter.bulk_appender() for member in values or (): if member in additions: - new_adapter.append_with_event(member) + appender(member) elif member in constants: - new_adapter.append_without_event(member) + appender(member, _sa_initiator=False) if existing_adapter: + remover = existing_adapter.bulk_remover() for member in removals: - existing_adapter.remove_with_event(member) + remover(member) + def prepare_instrumentation(factory): """Prepare a callable for future use as a collection class factory. @@ -678,7 +790,7 @@ def prepare_instrumentation(factory): # Did factory callable return a builtin? if cls in __canned_instrumentation: # Wrap it so that it returns our 'Instrumented*' - factory = __converting_factory(factory) + factory = __converting_factory(cls, factory) cls = factory() # Instrument the class if needed. @@ -691,51 +803,29 @@ def prepare_instrumentation(factory): return factory -def __converting_factory(original_factory): - """Convert the type returned by collection factories on the fly. 
- Given a collection factory that returns a builtin type (e.g. a list), - return a wrapped function that converts that type to one of our - instrumented types. +def __converting_factory(specimen_cls, original_factory): + """Return a wrapper that converts a "canned" collection like + set, dict, list into the Instrumented* version. """ + + instrumented_cls = __canned_instrumentation[specimen_cls] + def wrapper(): collection = original_factory() - type_ = type(collection) - if type_ in __canned_instrumentation: - # return an instrumented type initialized from the factory's - # collection - return __canned_instrumentation[type_](collection) - else: - raise sa_exc.InvalidRequestError( - "Collection class factories must produce instances of a " - "single class.") - try: - # often flawed but better than nothing - wrapper.__name__ = "%sWrapper" % original_factory.__name__ - wrapper.__doc__ = original_factory.__doc__ - except: - pass + return instrumented_cls(collection) + + # often flawed but better than nothing + wrapper.__name__ = "%sWrapper" % original_factory.__name__ + wrapper.__doc__ = original_factory.__doc__ + return wrapper + def _instrument_class(cls): """Modify methods in a class and install instrumentation.""" - # TODO: more formally document this as a decoratorless/Python 2.3 - # option for specifying instrumentation. (likely doc'd here in code only, - # not in online docs.) Useful for C types too. - # - # __instrumentation__ = { - # 'rolename': 'methodname', # ... - # 'methods': { - # 'methodname': ('fire_{append,remove}_event', argspec, - # 'fire_{append,remove}_event'), - # 'append': ('fire_append_event', 1, None), - # '__setitem__': ('fire_append_event', 1, 'fire_remove_event'), - # 'pop': (None, None, 'fire_remove_event'), - # } - # } - # In the normal call flow, a request for any of the 3 basic collection # types is transformed into one of our trivial subclasses # (e.g. InstrumentedList). Catch anything else that sneaks in here... @@ -744,55 +834,80 @@ def _instrument_class(cls): "Can not instrument a built-in type. Use a " "subclass, even a trivial one.") + roles, methods = _locate_roles_and_methods(cls) + + _setup_canned_roles(cls, roles, methods) + + _assert_required_roles(cls, roles, methods) + + _set_collection_attributes(cls, roles, methods) + + +def _locate_roles_and_methods(cls): + """search for _sa_instrument_role-decorated methods in + method resolution order, assign to roles. + + """ + + roles = {} + methods = {} + + for supercls in cls.__mro__: + for name, method in vars(supercls).items(): + if not util.callable(method): + continue + + # note role declarations + if hasattr(method, '_sa_instrument_role'): + role = method._sa_instrument_role + assert role in ('appender', 'remover', 'iterator', + 'linker', 'converter') + roles.setdefault(role, name) + + # transfer instrumentation requests from decorated function + # to the combined queue + before, after = None, None + if hasattr(method, '_sa_instrument_before'): + op, argument = method._sa_instrument_before + assert op in ('fire_append_event', 'fire_remove_event') + before = op, argument + if hasattr(method, '_sa_instrument_after'): + op = method._sa_instrument_after + assert op in ('fire_append_event', 'fire_remove_event') + after = op + if before: + methods[name] = before + (after, ) + elif after: + methods[name] = None, None, after + return roles, methods + + +def _setup_canned_roles(cls, roles, methods): + """see if this class has "canned" roles based on a known + collection type (dict, set, list). 
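_setup_canned_roles() resolves roles through util.duck_type_collection(), so a class need not subclass list/set/dict to pick up the canned treatment; a sketch patterned on the documented __emulates__ approach (SetLike is hypothetical):

    from sqlalchemy.orm.collections import collection

    class SetLike(object):
        __emulates__ = set           # duck-typed as a set

        def __init__(self):
            self.data = set()

        @collection.appender         # sets have no appender; elect one
        def append(self, item):
            self.data.add(item)

        def remove(self, item):      # fills the canned 'remover' role
            self.data.remove(item)

        def __iter__(self):          # fills the canned 'iterator' role
            return iter(self.data)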
Apply those roles + as needed to the "roles" dictionary, and also + prepare "decorator" methods + + """ collection_type = util.duck_type_collection(cls) if collection_type in __interfaces: - roles = __interfaces[collection_type].copy() - decorators = roles.pop('_decorators', {}) - else: - roles, decorators = {}, {} + canned_roles, decorators = __interfaces[collection_type] + for role, name in canned_roles.items(): + roles.setdefault(role, name) - if hasattr(cls, '__instrumentation__'): - roles.update(copy.deepcopy(getattr(cls, '__instrumentation__'))) + # apply ABC auto-decoration to methods that need it + for method, decorator in decorators.items(): + fn = getattr(cls, method, None) + if (fn and method not in methods and + not hasattr(fn, '_sa_instrumented')): + setattr(cls, method, decorator(fn)) - methods = roles.pop('methods', {}) - for name in dir(cls): - method = getattr(cls, name, None) - if not util.callable(method): - continue +def _assert_required_roles(cls, roles, methods): + """ensure all roles are present, and apply implicit instrumentation if + needed - # note role declarations - if hasattr(method, '_sa_instrument_role'): - role = method._sa_instrument_role - assert role in ('appender', 'remover', 'iterator', - 'on_link', 'converter') - roles[role] = name - - # transfer instrumentation requests from decorated function - # to the combined queue - before, after = None, None - if hasattr(method, '_sa_instrument_before'): - op, argument = method._sa_instrument_before - assert op in ('fire_append_event', 'fire_remove_event') - before = op, argument - if hasattr(method, '_sa_instrument_after'): - op = method._sa_instrument_after - assert op in ('fire_append_event', 'fire_remove_event') - after = op - if before: - methods[name] = before[0], before[1], after - elif after: - methods[name] = None, None, after - - # apply ABC auto-decoration to methods that need it - for method, decorator in decorators.items(): - fn = getattr(cls, method, None) - if (fn and method not in methods and - not hasattr(fn, '_sa_instrumented')): - setattr(cls, method, decorator(fn)) - - # ensure all roles are present, and apply implicit instrumentation if - # needed + """ if 'appender' not in roles or not hasattr(cls, roles['appender']): raise sa_exc.ArgumentError( "Type %s must elect an appender method to be " @@ -814,24 +929,34 @@ def _instrument_class(cls): "Type %s must elect an iterator method to be " "a collection class" % cls.__name__) - # apply ad-hoc instrumentation from decorators, class-level defaults - # and implicit role declarations - for method, (before, argument, after) in methods.items(): - setattr(cls, method, - _instrument_membership_mutator(getattr(cls, method), + +def _set_collection_attributes(cls, roles, methods): + """apply ad-hoc instrumentation from decorators, class-level defaults + and implicit role declarations + + """ + for method_name, (before, argument, after) in methods.items(): + setattr(cls, method_name, + _instrument_membership_mutator(getattr(cls, method_name), before, argument, after)) # intern the role map - for role, method in roles.items(): - setattr(cls, '_sa_%s' % role, getattr(cls, method)) + for role, method_name in roles.items(): + setattr(cls, '_sa_%s' % role, getattr(cls, method_name)) + + cls._sa_adapter = None + + if not hasattr(cls, '_sa_converter'): + cls._sa_converter = None + cls._sa_instrumented = id(cls) - setattr(cls, '_sa_instrumented', id(cls)) def _instrument_membership_mutator(method, before, argument, after): - """Route method args and/or return 
value through the collection adapter.""" + """Route method args and/or return value through the collection + adapter.""" # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' if before: - fn_args = list(util.flatten_iterator(inspect.getargspec(method)[0])) - if type(argument) is int: + fn_args = list(util.flatten_iterator(inspect_getargspec(method)[0])) + if isinstance(argument, int): pos_arg = argument named_arg = len(fn_args) > argument and fn_args[argument] or None else: @@ -862,7 +987,7 @@ def _instrument_membership_mutator(method, before, argument, after): if initiator is False: executor = None else: - executor = getattr(args[0], '_sa_adapter', None) + executor = args[0]._sa_adapter if before and executor: getattr(executor, before)(value, initiator) @@ -874,42 +999,46 @@ def _instrument_membership_mutator(method, before, argument, after): if res is not None: getattr(executor, after)(res, initiator) return res - try: - wrapper._sa_instrumented = True - wrapper.__name__ = method.__name__ - wrapper.__doc__ = method.__doc__ - except: - pass + + wrapper._sa_instrumented = True + if hasattr(method, "_sa_instrument_role"): + wrapper._sa_instrument_role = method._sa_instrument_role + wrapper.__name__ = method.__name__ + wrapper.__doc__ = method.__doc__ return wrapper + def __set(collection, item, _sa_initiator=None): """Run set events, may eventually be inlined into decorators.""" - if _sa_initiator is not False and item is not None: - executor = getattr(collection, '_sa_adapter', None) + if _sa_initiator is not False: + executor = collection._sa_adapter if executor: - item = getattr(executor, 'fire_append_event')(item, _sa_initiator) + item = executor.fire_append_event(item, _sa_initiator) return item - + + def __del(collection, item, _sa_initiator=None): """Run del events, may eventually be inlined into decorators.""" - if _sa_initiator is not False and item is not None: - executor = getattr(collection, '_sa_adapter', None) + if _sa_initiator is not False: + executor = collection._sa_adapter if executor: - getattr(executor, 'fire_remove_event')(item, _sa_initiator) + executor.fire_remove_event(item, _sa_initiator) + def __before_delete(collection, _sa_initiator=None): """Special method to run 'commit existing value' methods""" - executor = getattr(collection, '_sa_adapter', None) + executor = collection._sa_adapter if executor: - getattr(executor, 'fire_pre_remove_event')(_sa_initiator) + executor.fire_pre_remove_event(_sa_initiator) + def _list_decorators(): """Tailored instrumentation wrappers for any list-like class.""" def _tidy(fn): - setattr(fn, '_sa_instrumented', True) - fn.__doc__ = getattr(getattr(list, fn.__name__), '__doc__') + fn._sa_instrumented = True + fn.__doc__ = getattr(list, fn.__name__).__doc__ def append(fn): def append(self, item, _sa_initiator=None): @@ -948,19 +1077,22 @@ def _list_decorators(): start = index.start or 0 if start < 0: start += len(self) - stop = index.stop or len(self) + if index.stop is not None: + stop = index.stop + else: + stop = len(self) if stop < 0: stop += len(self) - + if step == 1: - for i in xrange(start, stop, step): + for i in range(start, stop, step): if len(self) > start: del self[start] - + for i, item in enumerate(value): self.insert(i + start, item) else: - rng = range(start, stop, step) + rng = list(range(start, stop, step)) if len(value) != len(rng): raise ValueError( "attempt to assign sequence of size %s to " @@ -987,25 +1119,24 @@ def _list_decorators(): _tidy(__delitem__) return __delitem__ - # Py2K - def 
__setslice__(fn): - def __setslice__(self, start, end, values): - for value in self[start:end]: - __del(self, value) - values = [__set(self, value) for value in values] - fn(self, start, end, values) - _tidy(__setslice__) - return __setslice__ + if util.py2k: + def __setslice__(fn): + def __setslice__(self, start, end, values): + for value in self[start:end]: + __del(self, value) + values = [__set(self, value) for value in values] + fn(self, start, end, values) + _tidy(__setslice__) + return __setslice__ + + def __delslice__(fn): + def __delslice__(self, start, end): + for value in self[start:end]: + __del(self, value) + fn(self, start, end) + _tidy(__delslice__) + return __delslice__ - def __delslice__(fn): - def __delslice__(self, start, end): - for value in self[start:end]: - __del(self, value) - fn(self, start, end) - _tidy(__delslice__) - return __delslice__ - # end Py2K - def extend(fn): def extend(self, iterable): for value in iterable: @@ -1015,8 +1146,8 @@ def _list_decorators(): def __iadd__(fn): def __iadd__(self, iterable): - # list.__iadd__ takes any iterable and seems to let TypeError raise - # as-is instead of returning NotImplemented + # list.__iadd__ takes any iterable and seems to let TypeError + # raise as-is instead of returning NotImplemented for value in iterable: self.append(value) return self @@ -1032,6 +1163,15 @@ def _list_decorators(): _tidy(pop) return pop + if not util.py2k: + def clear(fn): + def clear(self, index=-1): + for item in self: + __del(self, item) + fn(self) + _tidy(clear) + return clear + # __imul__ : not wrapping this. all members of the collection are already # present, so no need to fire appends... wrapping it with an explicit # decorator is still possible, so events on *= can be had if they're @@ -1041,12 +1181,13 @@ def _list_decorators(): l.pop('_tidy') return l + def _dict_decorators(): """Tailored instrumentation wrappers for any dict-like mapping class.""" def _tidy(fn): - setattr(fn, '_sa_instrumented', True) - fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__') + fn._sa_instrumented = True + fn.__doc__ = getattr(dict, fn.__name__).__doc__ Unspecified = util.symbol('Unspecified') @@ -1105,48 +1246,38 @@ def _dict_decorators(): _tidy(setdefault) return setdefault - if sys.version_info < (2, 4): - def update(fn): - def update(self, other): - for key in other.keys(): - if key not in self or self[key] is not other[key]: - self[key] = other[key] - _tidy(update) - return update - else: - def update(fn): - def update(self, __other=Unspecified, **kw): - if __other is not Unspecified: - if hasattr(__other, 'keys'): - for key in __other.keys(): - if (key not in self or + def update(fn): + def update(self, __other=Unspecified, **kw): + if __other is not Unspecified: + if hasattr(__other, 'keys'): + for key in list(__other): + if (key not in self or self[key] is not __other[key]): - self[key] = __other[key] - else: - for key, value in __other: - if key not in self or self[key] is not value: - self[key] = value - for key in kw: - if key not in self or self[key] is not kw[key]: - self[key] = kw[key] - _tidy(update) - return update + self[key] = __other[key] + else: + for key, value in __other: + if key not in self or self[key] is not value: + self[key] = value + for key in kw: + if key not in self or self[key] is not kw[key]: + self[key] = kw[key] + _tidy(update) + return update l = locals().copy() l.pop('_tidy') l.pop('Unspecified') return l -if util.py3k: - _set_binop_bases = (set, frozenset) -else: - import sets - _set_binop_bases = (set, 
frozenset, sets.BaseSet) +_set_binop_bases = (set, frozenset) + def _set_binops_check_strict(self, obj): - """Allow only set, frozenset and self.__class__-derived objects in binops.""" + """Allow only set, frozenset and self.__class__-derived + objects in binops.""" return isinstance(obj, _set_binop_bases + (self.__class__,)) + def _set_binops_check_loose(self, obj): """Allow anything set-like to participate in set binops.""" return (isinstance(obj, _set_binop_bases + (self.__class__,)) or @@ -1157,8 +1288,8 @@ def _set_decorators(): """Tailored instrumentation wrappers for any set-like class.""" def _tidy(fn): - setattr(fn, '_sa_instrumented', True) - fn.__doc__ = getattr(getattr(set, fn.__name__), '__doc__') + fn._sa_instrumented = True + fn.__doc__ = getattr(set, fn.__name__).__doc__ Unspecified = util.symbol('Unspecified') @@ -1171,23 +1302,15 @@ def _set_decorators(): _tidy(add) return add - if sys.version_info < (2, 4): - def discard(fn): - def discard(self, value, _sa_initiator=None): - if value in self: - self.remove(value, _sa_initiator) - _tidy(discard) - return discard - else: - def discard(fn): - def discard(self, value, _sa_initiator=None): + def discard(fn): + def discard(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ + if value in self: + __del(self, value, _sa_initiator) # testlib.pragma exempt:__hash__ - if value in self: - __del(self, value, _sa_initiator) - # testlib.pragma exempt:__hash__ - fn(self, value) - _tidy(discard) - return discard + fn(self, value) + _tidy(discard) + return discard def remove(fn): def remove(self, value, _sa_initiator=None): @@ -1312,72 +1435,53 @@ def _set_decorators(): class InstrumentedList(list): """An instrumented version of the built-in list.""" - __instrumentation__ = { - 'appender': 'append', - 'remover': 'remove', - 'iterator': '__iter__', } class InstrumentedSet(set): """An instrumented version of the built-in set.""" - __instrumentation__ = { - 'appender': 'add', - 'remover': 'remove', - 'iterator': '__iter__', } class InstrumentedDict(dict): """An instrumented version of the built-in dict.""" - # Py3K - #__instrumentation__ = { - # 'iterator': 'values', } - # Py2K - __instrumentation__ = { - 'iterator': 'itervalues', } - # end Py2K - + __canned_instrumentation = { list: InstrumentedList, set: InstrumentedSet, dict: InstrumentedDict, - } +} __interfaces = { - list: {'appender': 'append', + list: ( + {'appender': 'append', 'remover': 'remove', + 'iterator': '__iter__'}, _list_decorators() + ), + + set: ({'appender': 'add', 'remover': 'remove', - 'iterator': '__iter__', - '_decorators': _list_decorators(), }, - set: {'appender': 'add', - 'remover': 'remove', - 'iterator': '__iter__', - '_decorators': _set_decorators(), }, + 'iterator': '__iter__'}, _set_decorators() + ), + # decorators are required for dicts and object collections. - # Py3K - #dict: {'iterator': 'values', - # '_decorators': _dict_decorators(), }, - # Py2K - dict: {'iterator': 'itervalues', - '_decorators': _dict_decorators(), }, - # end Py2K - # < 0.4 compatible naming, deprecated- use decorators instead. - None: {} - } + dict: ({'iterator': 'values'}, _dict_decorators()) if util.py3k + else ({'iterator': 'itervalues'}, _dict_decorators()), +} + class MappedCollection(dict): """A basic dictionary-based collection class. - Extends dict with the minimal bag semantics that collection classes require. 
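The MappedCollection class being revised here is what all of the mapped_collection() factories return; a usage sketch with an invented keying function, Note being a hypothetical mapped class:

    from sqlalchemy.orm import relationship
    from sqlalchemy.orm.collections import mapped_collection

    # key each Note by the first word of its text; keys must remain
    # immutable for the lifetime of the object
    notes = relationship(
        "Note",
        collection_class=mapped_collection(
            lambda note: note.text.split()[0]))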
- ``set`` and ``remove`` are implemented in terms of a keying function: any - callable that takes an object and returns an object for use as a dictionary - key. + Extends dict with the minimal bag semantics that collection + classes require. ``set`` and ``remove`` are implemented in terms + of a keying function: any callable that takes an object and + returns an object for use as a dictionary key. """ def __init__(self, keyfunc): """Create a new collection with keying provided by keyfunc. - keyfunc may be any callable any callable that takes an object and - returns an object for use as a dictionary key. + keyfunc may be any callable that takes an object and returns an object + for use as a dictionary key. The keyfunc will be called every time the ORM needs to add a member by value-only (such as when loading instances from the database) or @@ -1389,14 +1493,16 @@ class MappedCollection(dict): """ self.keyfunc = keyfunc + @collection.appender + @collection.internally_instrumented def set(self, value, _sa_initiator=None): """Add an item by value, consulting the keyfunc for the key.""" key = self.keyfunc(value) self.__setitem__(key, value, _sa_initiator) - set = collection.internally_instrumented(set) - set = collection.appender(set) + @collection.remover + @collection.internally_instrumented def remove(self, value, _sa_initiator=None): """Remove an item by value, consulting the keyfunc for the key.""" @@ -1411,9 +1517,8 @@ class MappedCollection(dict): "values after flush?" % (value, self[key], key)) self.__delitem__(key, _sa_initiator) - remove = collection.internally_instrumented(remove) - remove = collection.remover(remove) + @collection.converter def _convert(self, dictlike): """Validate and convert a dict-like object into values for set()ing. @@ -1431,8 +1536,17 @@ class MappedCollection(dict): new_key = self.keyfunc(value) if incoming_key != new_key: raise TypeError( - "Found incompatible key %r for value %r; this collection's " + "Found incompatible key %r for value %r; this " + "collection's " "keying function requires a key of %r for this value." % ( - incoming_key, value, new_key)) + incoming_key, value, new_key)) yield value - _convert = collection.converter(_convert) + +# ensure instrumentation is associated with +# these built-in classes; if a user-defined class +# subclasses these and uses @internally_instrumented, +# the superclass is otherwise not instrumented. +# see [ticket:2406]. +_instrument_class(MappedCollection) +_instrument_class(InstrumentedList) +_instrument_class(InstrumentedSet) diff --git a/sqlalchemy/orm/dependency.py b/sqlalchemy/orm/dependency.py index cbbfb08..a87ec56 100644 --- a/sqlalchemy/orm/dependency.py +++ b/sqlalchemy/orm/dependency.py @@ -1,34 +1,21 @@ # orm/dependency.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """Relationship dependencies. -Bridges the ``PropertyLoader`` (i.e. a ``relationship()``) and the -``UOWTransaction`` together to allow processing of relationship()-based -dependencies at flush time. - """ -from sqlalchemy import sql, util -import sqlalchemy.exceptions as sa_exc -from sqlalchemy.orm import attributes, exc, sync -from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY +from .. import sql, util, exc as sa_exc +from . 
import attributes, exc, sync, unitofwork, \ + util as mapperutil +from .interfaces import ONETOMANY, MANYTOONE, MANYTOMANY -def create_dependency_processor(prop): - types = { - ONETOMANY : OneToManyDP, - MANYTOONE: ManyToOneDP, - MANYTOMANY : ManyToManyDP, - } - return types[prop.direction](prop) - class DependencyProcessor(object): - has_dependencies = True - def __init__(self, prop): self.prop = prop self.cascade = prop.cascade @@ -40,145 +27,290 @@ class DependencyProcessor(object): self.passive_deletes = prop.passive_deletes self.passive_updates = prop.passive_updates self.enable_typechecks = prop.enable_typechecks - self.key = prop.key - self.dependency_marker = MapperStub(self.parent, self.mapper, self.key) - if not self.prop.synchronize_pairs: - raise sa_exc.ArgumentError("Can't build a DependencyProcessor for relationship %s. " - "No target attributes to populate between parent and child are present" % self.prop) + if self.passive_deletes: + self._passive_delete_flag = attributes.PASSIVE_NO_INITIALIZE + else: + self._passive_delete_flag = attributes.PASSIVE_OFF + if self.passive_updates: + self._passive_update_flag = attributes.PASSIVE_NO_INITIALIZE + else: + self._passive_update_flag = attributes.PASSIVE_OFF - def _get_instrumented_attribute(self): - """Return the ``InstrumentedAttribute`` handled by this - ``DependencyProecssor``. - - """ - return self.parent.class_manager.get_impl(self.key) + self.key = prop.key + if not self.prop.synchronize_pairs: + raise sa_exc.ArgumentError( + "Can't build a DependencyProcessor for relationship %s. " + "No target attributes to populate between parent and " + "child are present" % + self.prop) + + @classmethod + def from_relationship(cls, prop): + return _direction_to_processor[prop.direction](prop) def hasparent(self, state): """return True if the given object instance has a parent, - according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``. - - """ - # TODO: use correct API for this - return self._get_instrumented_attribute().hasparent(state) + according to the ``InstrumentedAttribute`` handled by this + ``DependencyProcessor``. - def register_dependencies(self, uowcommit): - """Tell a ``UOWTransaction`` what mappers are dependent on - which, with regards to the two or three mappers handled by - this ``DependencyProcessor``. + """ + return self.parent.class_manager.get_impl(self.key).hasparent(state) + + def per_property_preprocessors(self, uow): + """establish actions and dependencies related to a flush. + + These actions will operate on all relevant states in + the aggregate. + + """ + uow.register_preprocessor(self, True) + + def per_property_flush_actions(self, uow): + after_save = unitofwork.ProcessAll(uow, self, False, True) + before_delete = unitofwork.ProcessAll(uow, self, True, True) + + parent_saves = unitofwork.SaveUpdateAll( + uow, + self.parent.primary_base_mapper + ) + child_saves = unitofwork.SaveUpdateAll( + uow, + self.mapper.primary_base_mapper + ) + + parent_deletes = unitofwork.DeleteAll( + uow, + self.parent.primary_base_mapper + ) + child_deletes = unitofwork.DeleteAll( + uow, + self.mapper.primary_base_mapper + ) + + self.per_property_dependencies(uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete + ) + + def per_state_flush_actions(self, uow, states, isdelete): + """establish actions and dependencies related to a flush. + + These actions will operate on all relevant states + individually. 
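The uow.dependencies.update([...]) calls throughout this module record (before, after) edges that the flush later sorts topologically; the per-state actions built below are the fallback taken when that sort hits a cycle. A standalone sketch of the ordering idea — not SQLAlchemy's implementation, all names invented:

    from collections import defaultdict, deque

    def topological(nodes, edges):
        """Order nodes so each (before, after) edge is respected."""
        incoming = defaultdict(set)
        outgoing = defaultdict(set)
        for before, after in edges:
            incoming[after].add(before)
            outgoing[before].add(after)
        ready = deque(n for n in nodes if not incoming[n])
        order = []
        while ready:
            node = ready.popleft()
            order.append(node)
            for nxt in list(outgoing[node]):
                incoming[nxt].discard(node)
                if not incoming[nxt]:
                    ready.append(nxt)
        if len(order) != len(nodes):
            raise ValueError("cycle: expand into per-state actions")
        return order

    # mirrors the plain one-to-many wiring: parent rows are saved, the
    # relationship is processed, then child rows are saved.
    actions = ['parent_saves', 'after_save', 'child_saves']
    edges = [('parent_saves', 'after_save'), ('after_save', 'child_saves')]
    assert topological(actions, edges) == actions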
This occurs only if there are cycles + in the 'aggregated' version of events. """ - raise NotImplementedError() + parent_base_mapper = self.parent.primary_base_mapper + child_base_mapper = self.mapper.primary_base_mapper + child_saves = unitofwork.SaveUpdateAll(uow, child_base_mapper) + child_deletes = unitofwork.DeleteAll(uow, child_base_mapper) - def register_processors(self, uowcommit): - """Tell a ``UOWTransaction`` about this object as a processor, - which will be executed after that mapper's objects have been - saved or before they've been deleted. The process operation - manages attributes and dependent operations between two mappers. - - """ - raise NotImplementedError() - - def whose_dependent_on_who(self, state1, state2): - """Given an object pair assuming `obj2` is a child of `obj1`, - return a tuple with the dependent object second, or None if - there is no dependency. + # locate and disable the aggregate processors + # for this dependency - """ - if state1 is state2: - return None - elif self.direction == ONETOMANY: - return (state1, state2) + if isdelete: + before_delete = unitofwork.ProcessAll(uow, self, True, True) + before_delete.disabled = True else: - return (state2, state1) + after_save = unitofwork.ProcessAll(uow, self, False, True) + after_save.disabled = True - def process_dependencies(self, task, deplist, uowcommit, delete = False): - """This method is called during a flush operation to - synchronize data between a parent and child object. + # check if the "child" side is part of the cycle - It is called within the context of the various mappers and - sometimes individual objects sorted according to their - insert/update/delete order (topological sort). + if child_saves not in uow.cycles: + # based on the current dependencies we use, the saves/ + # deletes should always be in the 'cycles' collection + # together. if this changes, we will have to break up + # this method a bit more. + assert child_deletes not in uow.cycles - """ - raise NotImplementedError() + # child side is not part of the cycle, so we will link per-state + # actions to the aggregate "saves", "deletes" actions + child_actions = [ + (child_saves, False), (child_deletes, True) + ] + child_in_cycles = False + else: + child_in_cycles = True - def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): - """Used before the flushes' topological sort to traverse - through related objects and ensure every instance which will - require save/update/delete is properly added to the - UOWTransaction. + # check if the "parent" side is part of the cycle + if not isdelete: + parent_saves = unitofwork.SaveUpdateAll( + uow, + self.parent.base_mapper) + parent_deletes = before_delete = None + if parent_saves in uow.cycles: + parent_in_cycles = True + else: + parent_deletes = unitofwork.DeleteAll( + uow, + self.parent.base_mapper) + parent_saves = after_save = None + if parent_deletes in uow.cycles: + parent_in_cycles = True - """ - raise NotImplementedError() + # now create actions /dependencies for each state. + + for state in states: + # detect if there's anything changed or loaded + # by a preprocessor on this state/attribute. In the + # case of deletes we may try to load missing items here as well. 
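The cycles this per-state machinery exists for come from mappings whose save actions depend on themselves, the classic case being a self-referential relationship. A minimal runnable sketch of a flush that takes this path (model names invented):

    from sqlalchemy import Column, ForeignKey, Integer, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    class Node(Base):
        __tablename__ = 'node'
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey('node.id'))
        children = relationship('Node')

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(Node(children=[Node(children=[Node()])]))
    # SaveUpdateAll(Node) must run both before and after ProcessAll for
    # this relationship, so the aggregated graph is cyclic and the flush
    # falls back to the per-state actions assembled here.
    session.flush()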
+ sum_ = state.manager[self.key].impl.get_all_pending( + state, state.dict, + self._passive_delete_flag + if isdelete + else attributes.PASSIVE_NO_INITIALIZE) + + if not sum_: + continue + + if isdelete: + before_delete = unitofwork.ProcessState(uow, + self, True, state) + if parent_in_cycles: + parent_deletes = unitofwork.DeleteState( + uow, + state, + parent_base_mapper) + else: + after_save = unitofwork.ProcessState(uow, self, False, state) + if parent_in_cycles: + parent_saves = unitofwork.SaveUpdateState( + uow, + state, + parent_base_mapper) + + if child_in_cycles: + child_actions = [] + for child_state, child in sum_: + if child_state not in uow.states: + child_action = (None, None) + else: + (deleted, listonly) = uow.states[child_state] + if deleted: + child_action = ( + unitofwork.DeleteState( + uow, child_state, + child_base_mapper), + True) + else: + child_action = ( + unitofwork.SaveUpdateState( + uow, child_state, + child_base_mapper), + False) + child_actions.append(child_action) + + # establish dependencies between our possibly per-state + # parent action and our possibly per-state child action. + for child_action, childisdelete in child_actions: + self.per_state_dependencies(uow, parent_saves, + parent_deletes, + child_action, + after_save, before_delete, + isdelete, childisdelete) + + def presort_deletes(self, uowcommit, states): + return False + + def presort_saves(self, uowcommit, states): + return False + + def process_deletes(self, uowcommit, states): + pass + + def process_saves(self, uowcommit, states): + pass + + def prop_has_changes(self, uowcommit, states, isdelete): + if not isdelete or self.passive_deletes: + passive = attributes.PASSIVE_NO_INITIALIZE + elif self.direction is MANYTOONE: + passive = attributes.PASSIVE_NO_FETCH_RELATED + else: + passive = attributes.PASSIVE_OFF + + for s in states: + # TODO: add a high speed method + # to InstanceState which returns: attribute + # has a non-None value, or had one + history = uowcommit.get_attribute_history( + s, + self.key, + passive) + if history and not history.empty(): + return True + else: + return states and \ + not self.prop._is_self_referential and \ + self.mapper in uowcommit.mappers def _verify_canload(self, state): - if state is not None and \ - not self.mapper._canload(state, allow_subtypes=not self.enable_typechecks): + if self.prop.uselist and state is None: + raise exc.FlushError( + "Can't flush None value found in " + "collection %s" % (self.prop, )) + elif state is not None and \ + not self.mapper._canload( + state, allow_subtypes=not self.enable_typechecks): if self.mapper._canload(state, allow_subtypes=True): - raise exc.FlushError( - "Attempting to flush an item of type %s on collection '%s', " - "which is not the expected type %s. Configure mapper '%s' to " - "load this subtype polymorphically, or set " - "enable_typechecks=False to allow subtypes. " - "Mismatched typeloading may cause bi-directional relationships " - "(backrefs) to not function properly." % - (state.class_, self.prop, self.mapper.class_, self.mapper)) + raise exc.FlushError('Attempting to flush an item of type ' + '%(x)s as a member of collection ' + '"%(y)s". Expected an object of type ' + '%(z)s or a polymorphic subclass of ' + 'this type. If %(x)s is a subclass of ' + '%(z)s, configure mapper "%(zm)s" to ' + 'load this subtype polymorphically, or ' + 'set enable_typechecks=False to allow ' + 'any subtype to be accepted for flush. 
' + % { + 'x': state.class_, + 'y': self.prop, + 'z': self.mapper.class_, + 'zm': self.mapper, + }) else: raise exc.FlushError( - "Attempting to flush an item of type %s on collection '%s', " - "whose mapper does not inherit from that of %s." % - (state.class_, self.prop, self.mapper.class_)) - - def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): - """Called during a flush to synchronize primary key identifier - values between a parent/child object, as well as to an - associationrow in the case of many-to-many. - - """ + 'Attempting to flush an item of type ' + '%(x)s as a member of collection ' + '"%(y)s". Expected an object of type ' + '%(z)s or a polymorphic subclass of ' + 'this type.' % { + 'x': state.class_, + 'y': self.prop, + 'z': self.mapper.class_, + }) + + def _synchronize(self, state, child, associationrow, + clearkeys, uowcommit): raise NotImplementedError() - def _check_reverse_action(self, uowcommit, parent, child, action): - """Determine if an action has been performed by the 'reverse' property of this property. - - this is used to ensure that only one side of a bidirectional relationship - issues a certain operation for a parent/child pair. - - """ - for r in self.prop._reverse_property: - if not r.viewonly and (r._dependency_processor, action, parent, child) in uowcommit.attributes: - return True - return False - - def _performed_action(self, uowcommit, parent, child, action): - """Establish that an action has been performed for a certain parent/child pair. - - Used only for actions that are sensitive to bidirectional double-action, - i.e. manytomany, post_update. - - """ - uowcommit.attributes[(self, action, parent, child)] = True - - def _conditional_post_update(self, state, uowcommit, related): - """Execute a post_update call. + def _get_reversed_processed_set(self, uow): + if not self.prop._reverse_property: + return None - For relationships that contain the post_update flag, an additional - ``UPDATE`` statement may be associated after an ``INSERT`` or - before a ``DELETE`` in order to resolve circular row - dependencies. + process_key = tuple(sorted( + [self.key] + + [p.key for p in self.prop._reverse_property] + )) + return uow.memo( + ('reverse_key', process_key), + set + ) - This method will check for the post_update flag being set on a - particular relationship, and given a target object and list of - one or more related objects, and execute the ``UPDATE`` if the - given related object list contains ``INSERT``s or ``DELETE``s. 
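The reworded FlushError above concerns flushing a subtype into a collection of the base type. A runnable sketch of the polymorphic configuration its message points users toward (model names invented):

    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    class Department(Base):
        __tablename__ = 'department'
        id = Column(Integer, primary_key=True)
        employees = relationship('Employee')

    class Employee(Base):
        __tablename__ = 'employee'
        id = Column(Integer, primary_key=True)
        dept_id = Column(Integer, ForeignKey('department.id'))
        type = Column(String(20))
        # the polymorphic settings the error message asks for; without
        # them, flushing a Manager below is the "Attempting to flush an
        # item of type ..." case the new wording describes
        __mapper_args__ = {'polymorphic_on': type,
                           'polymorphic_identity': 'employee'}

    class Manager(Employee):
        __mapper_args__ = {'polymorphic_identity': 'manager'}

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add(Department(employees=[Manager()]))
    session.flush()   # accepted: Manager loads polymorphically as Employee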
- - """ - if state is not None and self.post_update: - for x in related: - if x is not None and not self._check_reverse_action(uowcommit, x, state, "postupdate"): - uowcommit.register_object(state, postupdate=True, post_update_cols=[r for l, r in self.prop.synchronize_pairs]) - self._performed_action(uowcommit, x, state, "postupdate") - break + def _post_update(self, state, uowcommit, related, is_m2o_delete=False): + for x in related: + if not is_m2o_delete or x is not None: + uowcommit.issue_post_update( + state, + [r for l, r in self.prop.synchronize_pairs] + ) + break def _pks_changed(self, uowcommit, state): raise NotImplementedError() @@ -186,390 +318,858 @@ class DependencyProcessor(object): def __repr__(self): return "%s(%s)" % (self.__class__.__name__, self.prop) + class OneToManyDP(DependencyProcessor): - def register_dependencies(self, uowcommit): + + def per_property_dependencies(self, uow, parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): if self.post_update: - uowcommit.register_dependency(self.mapper, self.dependency_marker) - uowcommit.register_dependency(self.parent, self.dependency_marker) - else: - uowcommit.register_dependency(self.parent, self.mapper) + child_post_updates = unitofwork.IssuePostUpdate( + uow, + self.mapper.primary_base_mapper, + False) + child_pre_updates = unitofwork.IssuePostUpdate( + uow, + self.mapper.primary_base_mapper, + True) + + uow.dependencies.update([ + (child_saves, after_save), + (parent_saves, after_save), + (after_save, child_post_updates), + + (before_delete, child_pre_updates), + (child_pre_updates, parent_deletes), + (child_pre_updates, child_deletes), + + ]) + else: + uow.dependencies.update([ + (parent_saves, after_save), + (after_save, child_saves), + (after_save, child_deletes), + + (child_saves, parent_deletes), + (child_deletes, parent_deletes), + + (before_delete, child_saves), + (before_delete, child_deletes), + ]) + + def per_state_dependencies(self, uow, + save_parent, + delete_parent, + child_action, + after_save, before_delete, + isdelete, childisdelete): - def register_processors(self, uowcommit): if self.post_update: - uowcommit.register_processor(self.dependency_marker, self, self.parent) - else: - uowcommit.register_processor(self.parent, self, self.parent) - def process_dependencies(self, task, deplist, uowcommit, delete = False): - if delete: - # head object is being deleted, and we manage its list of child objects - # the child objects have to have their foreign key to the parent set to NULL - # this phase can be called safely for any cascade but is unnecessary if delete cascade - # is on. 
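The IssuePostUpdate actions wired up above implement relationship(post_update=True), whose canonical use is a pair of tables that reference each other. A mapping-only sketch (table and class names invented; schema creation is omitted since mutually dependent FKs additionally need use_alter handling at CREATE time):

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class Widget(Base):
        __tablename__ = 'widget'
        widget_id = Column(Integer, primary_key=True)
        favorite_entry_id = Column(Integer, ForeignKey('entry.entry_id'))
        entries = relationship(
            'Entry', primaryjoin='Widget.widget_id == Entry.widget_id')
        # post_update=True: the widget row is INSERTed with
        # favorite_entry_id unset, then the child_post_updates action
        # emits a second UPDATE once the entry row exists, breaking the
        # INSERT-ordering cycle between the two tables.
        favorite_entry = relationship(
            'Entry',
            primaryjoin='Widget.favorite_entry_id == Entry.entry_id',
            post_update=True)

    class Entry(Base):
        __tablename__ = 'entry'
        entry_id = Column(Integer, primary_key=True)
        widget_id = Column(Integer, ForeignKey('widget.widget_id'))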
- if self.post_update or not self.passive_deletes == 'all': - for state in deplist: - history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) - if history: - for child in history.deleted: - if child is not None and self.hasparent(child) is False: - self._synchronize(state, child, None, True, uowcommit) - self._conditional_post_update(child, uowcommit, [state]) - if self.post_update or not self.cascade.delete: - for child in history.unchanged: - if child is not None: - self._synchronize(state, child, None, True, uowcommit) - self._conditional_post_update(child, uowcommit, [state]) - else: - for state in deplist: - history = uowcommit.get_attribute_history(state, self.key, passive=True) - if history: - for child in history.added: - self._synchronize(state, child, None, False, uowcommit) - if child is not None: - self._conditional_post_update(child, uowcommit, [state]) + child_post_updates = unitofwork.IssuePostUpdate( + uow, + self.mapper.primary_base_mapper, + False) + child_pre_updates = unitofwork.IssuePostUpdate( + uow, + self.mapper.primary_base_mapper, + True) - for child in history.deleted: - if not self.cascade.delete_orphan and not self.hasparent(child): - self._synchronize(state, child, None, True, uowcommit) - - if self._pks_changed(uowcommit, state): - for child in history.unchanged: - self._synchronize(state, child, None, False, uowcommit) - - def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): - if delete: - # head object is being deleted, and we manage its list of child objects - # the child objects have to have their foreign key to the parent set to NULL - if not self.post_update: - should_null_fks = not self.cascade.delete and not self.passive_deletes == 'all' - for state in deplist: - history = uowcommit.get_attribute_history( - state, self.key, passive=self.passive_deletes) - if history: - for child in history.deleted: - if child is not None and self.hasparent(child) is False: - if self.cascade.delete_orphan: - uowcommit.register_object(child, isdelete=True) - else: - uowcommit.register_object(child) - if should_null_fks: - for child in history.unchanged: - if child is not None: - uowcommit.register_object(child) + # TODO: this whole block is not covered + # by any tests + if not isdelete: + if childisdelete: + uow.dependencies.update([ + (child_action, after_save), + (after_save, child_post_updates), + ]) + else: + uow.dependencies.update([ + (save_parent, after_save), + (child_action, after_save), + (after_save, child_post_updates), + ]) + else: + if childisdelete: + uow.dependencies.update([ + (before_delete, child_pre_updates), + (child_pre_updates, delete_parent), + ]) + else: + uow.dependencies.update([ + (before_delete, child_pre_updates), + (child_pre_updates, delete_parent), + ]) + elif not isdelete: + uow.dependencies.update([ + (save_parent, after_save), + (after_save, child_action), + (save_parent, child_action) + ]) else: - for state in deplist: - history = uowcommit.get_attribute_history(state, self.key, passive=True) - if history: - for child in history.added: - if child is not None: - uowcommit.register_object(child) - for child in history.deleted: - if not self.cascade.delete_orphan: - uowcommit.register_object(child, isdelete=False) - elif self.hasparent(child) is False: + uow.dependencies.update([ + (before_delete, child_action), + (child_action, delete_parent) + ]) + + def presort_deletes(self, uowcommit, states): + # head object is being deleted, and we manage its list of + # child objects the 
child objects have to have their + # foreign key to the parent set to NULL + should_null_fks = not self.cascade.delete and \ + not self.passive_deletes == 'all' + + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + for child in history.deleted: + if child is not None and self.hasparent(child) is False: + if self.cascade.delete_orphan: uowcommit.register_object(child, isdelete=True) - for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object( - attributes.instance_state(c), - isdelete=True) - if self._pks_changed(uowcommit, state): - if not history: - history = uowcommit.get_attribute_history( - state, self.key, passive=self.passive_updates) - if history: - for child in history.unchanged: - if child is not None: - uowcommit.register_object(child) + else: + uowcommit.register_object(child) - def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + if should_null_fks: + for child in history.unchanged: + if child is not None: + uowcommit.register_object( + child, operation="delete", prop=self.prop) + + def presort_saves(self, uowcommit, states): + children_added = uowcommit.memo(('children_added', self), set) + + for state in states: + pks_changed = self._pks_changed(uowcommit, state) + + if not pks_changed or self.passive_updates: + passive = attributes.PASSIVE_NO_INITIALIZE + else: + passive = attributes.PASSIVE_OFF + + history = uowcommit.get_attribute_history( + state, + self.key, + passive) + if history: + for child in history.added: + if child is not None: + uowcommit.register_object(child, cancel_delete=True, + operation="add", + prop=self.prop) + + children_added.update(history.added) + + for child in history.deleted: + if not self.cascade.delete_orphan: + uowcommit.register_object(child, isdelete=False, + operation='delete', + prop=self.prop) + elif self.hasparent(child) is False: + uowcommit.register_object( + child, isdelete=True, + operation="delete", prop=self.prop) + for c, m, st_, dct_ in self.mapper.cascade_iterator( + 'delete', child): + uowcommit.register_object( + st_, + isdelete=True) + + if pks_changed: + if history: + for child in history.unchanged: + if child is not None: + uowcommit.register_object( + child, + False, + self.passive_updates, + operation="pk change", + prop=self.prop) + + def process_deletes(self, uowcommit, states): + # head object is being deleted, and we manage its list of + # child objects the child objects have to have their foreign + # key to the parent set to NULL this phase can be called + # safely for any cascade but is unnecessary if delete cascade + # is on. 
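The delete-orphan branch of presort_saves() above registers de-associated children for DELETE rather than the FK-nulling UPDATE of the default path. A runnable sketch (model names invented):

    from sqlalchemy import Column, ForeignKey, Integer, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    class Parent(Base):
        __tablename__ = 'parent'
        id = Column(Integer, primary_key=True)
        # delete-orphan: children removed from this collection take the
        # hasparent()-is-False path above, registered with isdelete=True
        children = relationship('Child', cascade='all, delete-orphan')

    class Child(Base):
        __tablename__ = 'child'
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey('parent.id'))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = Session(engine)
    p = Parent(children=[Child()])
    session.add(p)
    session.flush()
    p.children.pop()   # orphaned
    session.flush()    # emits DELETE for the child, not the
                       # UPDATE ... SET parent_id=NULL of the default path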
+ + if self.post_update or not self.passive_deletes == 'all': + children_added = uowcommit.memo(('children_added', self), set) + + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + for child in history.deleted: + if child is not None and \ + self.hasparent(child) is False: + self._synchronize( + state, + child, + None, True, + uowcommit, False) + if self.post_update and child: + self._post_update(child, uowcommit, [state]) + + if self.post_update or not self.cascade.delete: + for child in set(history.unchanged).\ + difference(children_added): + if child is not None: + self._synchronize( + state, + child, + None, True, + uowcommit, False) + if self.post_update and child: + self._post_update(child, + uowcommit, + [state]) + + # technically, we can even remove each child from the + # collection here too. but this would be a somewhat + # inconsistent behavior since it wouldn't happen + # if the old parent wasn't deleted but child was moved. + + def process_saves(self, uowcommit, states): + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + attributes.PASSIVE_NO_INITIALIZE) + if history: + for child in history.added: + self._synchronize(state, child, None, + False, uowcommit, False) + if child is not None and self.post_update: + self._post_update(child, uowcommit, [state]) + + for child in history.deleted: + if not self.cascade.delete_orphan and \ + not self.hasparent(child): + self._synchronize(state, child, None, True, + uowcommit, False) + + if self._pks_changed(uowcommit, state): + for child in history.unchanged: + self._synchronize(state, child, None, + False, uowcommit, True) + + def _synchronize(self, state, child, + associationrow, clearkeys, uowcommit, + pks_changed): source = state dest = child - if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): - return self._verify_canload(child) + if dest is None or \ + (not self.post_update and uowcommit.is_deleted(dest)): + return if clearkeys: sync.clear(dest, self.mapper, self.prop.synchronize_pairs) else: - sync.populate(source, self.parent, dest, self.mapper, - self.prop.synchronize_pairs, uowcommit, - self.passive_updates) + sync.populate(source, self.parent, dest, self.mapper, + self.prop.synchronize_pairs, uowcommit, + self.passive_updates and pks_changed) def _pks_changed(self, uowcommit, state): - return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs) + return sync.source_modified( + uowcommit, + state, + self.parent, + self.prop.synchronize_pairs) -class DetectKeySwitch(DependencyProcessor): - """a special DP that works for many-to-one relationships, fires off for - child items who have changed their referenced key.""" - - has_dependencies = False - - def register_dependencies(self, uowcommit): - pass - - def register_processors(self, uowcommit): - uowcommit.register_processor(self.parent, self, self.mapper) - - def preprocess_dependencies(self, task, deplist, uowcommit, delete=False): - # for non-passive updates, register in the preprocess stage - # so that mapper save_obj() gets a hold of changes - if not delete and not self.passive_updates: - self._process_key_switches(deplist, uowcommit) - - def process_dependencies(self, task, deplist, uowcommit, delete=False): - # for passive updates, register objects in the process stage - # so that we avoid ManyToOneDP's registering the object without - # the listonly flag in its own preprocess stage (results in UPDATE) - 
# statements being emitted - if not delete and self.passive_updates: - self._process_key_switches(deplist, uowcommit) - - def _process_key_switches(self, deplist, uowcommit): - switchers = set(s for s in deplist if self._pks_changed(uowcommit, s)) - if switchers: - # yes, we're doing a linear search right now through the UOW. only - # takes effect when primary key values have actually changed. - # a possible optimization might be to enhance the "hasparents" capability of - # attributes to actually store all parent references, but this introduces - # more complicated attribute accounting. - for s in [elem for elem in uowcommit.session.identity_map.all_states() - if issubclass(elem.class_, self.parent.class_) and - self.key in elem.dict and - elem.dict[self.key] is not None and - attributes.instance_state(elem.dict[self.key]) in switchers - ]: - uowcommit.register_object(s) - sync.populate( - attributes.instance_state(s.dict[self.key]), - self.mapper, s, self.parent, self.prop.synchronize_pairs, - uowcommit, self.passive_updates) - - def _pks_changed(self, uowcommit, state): - return sync.source_modified(uowcommit, state, self.mapper, self.prop.synchronize_pairs) class ManyToOneDP(DependencyProcessor): def __init__(self, prop): DependencyProcessor.__init__(self, prop) self.mapper._dependency_processors.append(DetectKeySwitch(prop)) - def register_dependencies(self, uowcommit): - if self.post_update: - uowcommit.register_dependency(self.mapper, self.dependency_marker) - uowcommit.register_dependency(self.parent, self.dependency_marker) - else: - uowcommit.register_dependency(self.mapper, self.parent) - - def register_processors(self, uowcommit): - if self.post_update: - uowcommit.register_processor(self.dependency_marker, self, self.parent) - else: - uowcommit.register_processor(self.mapper, self, self.parent) + def per_property_dependencies(self, uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete): - def process_dependencies(self, task, deplist, uowcommit, delete=False): - if delete: - if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all': - # post_update means we have to update our row to not reference the child object - # before we can DELETE the row - for state in deplist: - self._synchronize(state, None, None, True, uowcommit) - history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) - if history: - self._conditional_post_update(state, uowcommit, history.sum()) + if self.post_update: + parent_post_updates = unitofwork.IssuePostUpdate( + uow, + self.parent.primary_base_mapper, + False) + parent_pre_updates = unitofwork.IssuePostUpdate( + uow, + self.parent.primary_base_mapper, + True) + + uow.dependencies.update([ + (child_saves, after_save), + (parent_saves, after_save), + (after_save, parent_post_updates), + + (after_save, parent_pre_updates), + (before_delete, parent_pre_updates), + + (parent_pre_updates, child_deletes), + ]) else: - for state in deplist: - history = uowcommit.get_attribute_history(state, self.key, passive=True) + uow.dependencies.update([ + (child_saves, after_save), + (after_save, parent_saves), + (parent_saves, child_deletes), + (parent_deletes, child_deletes) + ]) + + def per_state_dependencies(self, uow, + save_parent, + delete_parent, + child_action, + after_save, before_delete, + isdelete, childisdelete): + + if self.post_update: + + if not isdelete: + parent_post_updates = unitofwork.IssuePostUpdate( + uow, + self.parent.primary_base_mapper, 
+ False) + if childisdelete: + uow.dependencies.update([ + (after_save, parent_post_updates), + (parent_post_updates, child_action) + ]) + else: + uow.dependencies.update([ + (save_parent, after_save), + (child_action, after_save), + + (after_save, parent_post_updates) + ]) + else: + parent_pre_updates = unitofwork.IssuePostUpdate( + uow, + self.parent.primary_base_mapper, + True) + + uow.dependencies.update([ + (before_delete, parent_pre_updates), + (parent_pre_updates, delete_parent), + (parent_pre_updates, child_action) + ]) + + elif not isdelete: + if not childisdelete: + uow.dependencies.update([ + (child_action, after_save), + (after_save, save_parent), + ]) + else: + uow.dependencies.update([ + (after_save, save_parent), + ]) + + else: + if childisdelete: + uow.dependencies.update([ + (delete_parent, child_action) + ]) + + def presort_deletes(self, uowcommit, states): + if self.cascade.delete or self.cascade.delete_orphan: + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) if history: + if self.cascade.delete_orphan: + todelete = history.sum() + else: + todelete = history.non_deleted() + for child in todelete: + if child is None: + continue + uowcommit.register_object( + child, isdelete=True, + operation="delete", prop=self.prop) + t = self.mapper.cascade_iterator('delete', child) + for c, m, st_, dct_ in t: + uowcommit.register_object( + st_, isdelete=True) + + def presort_saves(self, uowcommit, states): + for state in states: + uowcommit.register_object(state, operation="add", prop=self.prop) + if self.cascade.delete_orphan: + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + for child in history.deleted: + if self.hasparent(child) is False: + uowcommit.register_object( + child, isdelete=True, + operation="delete", prop=self.prop) + + t = self.mapper.cascade_iterator('delete', child) + for c, m, st_, dct_ in t: + uowcommit.register_object(st_, isdelete=True) + + def process_deletes(self, uowcommit, states): + if self.post_update and \ + not self.cascade.delete_orphan and \ + not self.passive_deletes == 'all': + + # post_update means we have to update our + # row to not reference the child object + # before we can DELETE the row + for state in states: + self._synchronize(state, None, None, True, uowcommit) + if state and self.post_update: + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + self._post_update( + state, uowcommit, history.sum(), + is_m2o_delete=True) + + def process_saves(self, uowcommit, states): + for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + attributes.PASSIVE_NO_INITIALIZE) + if history: + if history.added: for child in history.added: - self._synchronize(state, child, None, False, uowcommit) - self._conditional_post_update(state, uowcommit, history.sum()) + self._synchronize(state, child, None, False, + uowcommit, "add") + if self.post_update: + self._post_update(state, uowcommit, history.sum()) - def preprocess_dependencies(self, task, deplist, uowcommit, delete=False): - if self.post_update: + def _synchronize(self, state, child, associationrow, + clearkeys, uowcommit, operation=None): + if state is None or \ + (not self.post_update and uowcommit.is_deleted(state)): return - if delete: - if self.cascade.delete or self.cascade.delete_orphan: - for state in deplist: - history = uowcommit.get_attribute_history(state, self.key, 
passive=self.passive_deletes) - if history: - if self.cascade.delete_orphan: - todelete = history.sum() - else: - todelete = history.non_deleted() - for child in todelete: - if child is None: - continue - uowcommit.register_object(child, isdelete=True) - for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object( - attributes.instance_state(c), isdelete=True) - else: - for state in deplist: - uowcommit.register_object(state) - if self.cascade.delete_orphan: - history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) - if history: - for child in history.deleted: - if self.hasparent(child) is False: - uowcommit.register_object(child, isdelete=True) - for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object( - attributes.instance_state(c), - isdelete=True) - - def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): - if state is None or (not self.post_update and uowcommit.is_deleted(state)): + if operation is not None and \ + child is not None and \ + not uowcommit.session._contains_state(child): + util.warn( + "Object of type %s not in session, %s " + "operation along '%s' won't proceed" % + (mapperutil.state_class_str(child), operation, self.prop)) return if clearkeys or child is None: sync.clear(state, self.parent, self.prop.synchronize_pairs) else: self._verify_canload(child) - sync.populate(child, self.mapper, state, - self.parent, self.prop.synchronize_pairs, uowcommit, - self.passive_updates - ) + sync.populate(child, self.mapper, state, + self.parent, + self.prop.synchronize_pairs, + uowcommit, + False) + + +class DetectKeySwitch(DependencyProcessor): + """For many-to-one relationships with no one-to-many backref, + searches for parents through the unit of work when a primary + key has changed and updates them. + + Theoretically, this approach could be expanded to support transparent + deletion of objects referenced via many-to-one as well, although + the current attribute system doesn't do enough bookkeeping for this + to be efficient. 
+ + """ + + def per_property_preprocessors(self, uow): + if self.prop._reverse_property: + if self.passive_updates: + return + else: + if False in (prop.passive_updates for + prop in self.prop._reverse_property): + return + + uow.register_preprocessor(self, False) + + def per_property_flush_actions(self, uow): + parent_saves = unitofwork.SaveUpdateAll( + uow, + self.parent.base_mapper) + after_save = unitofwork.ProcessAll(uow, self, False, False) + uow.dependencies.update([ + (parent_saves, after_save) + ]) + + def per_state_flush_actions(self, uow, states, isdelete): + pass + + def presort_deletes(self, uowcommit, states): + pass + + def presort_saves(self, uow, states): + if not self.passive_updates: + # for non-passive updates, register in the preprocess stage + # so that mapper save_obj() gets a hold of changes + self._process_key_switches(states, uow) + + def prop_has_changes(self, uow, states, isdelete): + if not isdelete and self.passive_updates: + d = self._key_switchers(uow, states) + return bool(d) + + return False + + def process_deletes(self, uowcommit, states): + assert False + + def process_saves(self, uowcommit, states): + # for passive updates, register objects in the process stage + # so that we avoid ManyToOneDP's registering the object without + # the listonly flag in its own preprocess stage (results in UPDATE) + # statements being emitted + assert self.passive_updates + self._process_key_switches(states, uowcommit) + + def _key_switchers(self, uow, states): + switched, notswitched = uow.memo( + ('pk_switchers', self), + lambda: (set(), set()) + ) + + allstates = switched.union(notswitched) + for s in states: + if s not in allstates: + if self._pks_changed(uow, s): + switched.add(s) + else: + notswitched.add(s) + return switched + + def _process_key_switches(self, deplist, uowcommit): + switchers = self._key_switchers(uowcommit, deplist) + if switchers: + # if primary key values have actually changed somewhere, perform + # a linear search through the UOW in search of a parent. + for state in uowcommit.session.identity_map.all_states(): + if not issubclass(state.class_, self.parent.class_): + continue + dict_ = state.dict + related = state.get_impl(self.key).get( + state, dict_, passive=self._passive_update_flag) + if related is not attributes.PASSIVE_NO_RESULT and \ + related is not None: + related_state = attributes.instance_state(dict_[self.key]) + if related_state in switchers: + uowcommit.register_object(state, + False, + self.passive_updates) + sync.populate( + related_state, + self.mapper, state, + self.parent, self.prop.synchronize_pairs, + uowcommit, self.passive_updates) + + def _pks_changed(self, uowcommit, state): + return bool(state.key) and sync.source_modified( + uowcommit, state, self.mapper, self.prop.synchronize_pairs) + class ManyToManyDP(DependencyProcessor): - def register_dependencies(self, uowcommit): - # many-to-many. create a "Stub" mapper to represent the - # "middle table" in the relationship. This stub mapper doesnt save - # or delete any objects, but just marks a dependency on the two - # related mappers. its dependency processor then populates the - # association table. 
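For reference, the situation DetectKeySwitch's docstring describes, runnable end to end (model names invented):

    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'
        username = Column(String(50), primary_key=True)

    class Address(Base):
        __tablename__ = 'address'
        id = Column(Integer, primary_key=True)
        username = Column(String(50), ForeignKey('user.username'))
        # many-to-one with no one-to-many backref; passive_updates=False
        # means SQLAlchemy, not ON UPDATE CASCADE, propagates PK changes
        user = relationship('User', passive_updates=False)

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = Session(engine)
    u = User(username='jack')
    session.add_all([u, Address(user=u)])
    session.flush()
    u.username = 'ed'   # primary key switch
    session.flush()     # _process_key_switches() finds the Address via
                        # the identity map and UPDATEs its username too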
- uowcommit.register_dependency(self.parent, self.dependency_marker) - uowcommit.register_dependency(self.mapper, self.dependency_marker) + def per_property_dependencies(self, uow, parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete + ): - def register_processors(self, uowcommit): - uowcommit.register_processor(self.dependency_marker, self, self.parent) - - def process_dependencies(self, task, deplist, uowcommit, delete = False): - connection = uowcommit.transaction.connection(self.mapper) + uow.dependencies.update([ + (parent_saves, after_save), + (child_saves, after_save), + (after_save, child_deletes), + + # a rowswitch on the parent from deleted to saved + # can make this one occur, as the "save" may remove + # an element from the + # "deleted" list before we have a chance to + # process its child rows + (before_delete, parent_saves), + + (before_delete, parent_deletes), + (before_delete, child_deletes), + (before_delete, child_saves), + ]) + + def per_state_dependencies(self, uow, + save_parent, + delete_parent, + child_action, + after_save, before_delete, + isdelete, childisdelete): + if not isdelete: + if childisdelete: + uow.dependencies.update([ + (save_parent, after_save), + (after_save, child_action), + ]) + else: + uow.dependencies.update([ + (save_parent, after_save), + (child_action, after_save), + ]) + else: + uow.dependencies.update([ + (before_delete, child_action), + (before_delete, delete_parent) + ]) + + def presort_deletes(self, uowcommit, states): + # TODO: no tests fail if this whole + # thing is removed !!!! + if not self.passive_deletes: + # if no passive deletes, load history on + # the collection, so that prop_has_changes() + # returns True + for state in states: + uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + + def presort_saves(self, uowcommit, states): + if not self.passive_updates: + # if no passive updates, load history on + # each collection where parent has changed PK, + # so that prop_has_changes() returns True + for state in states: + if self._pks_changed(uowcommit, state): + history = uowcommit.get_attribute_history( + state, + self.key, + attributes.PASSIVE_OFF) + + if not self.cascade.delete_orphan: + return + + # check for child items removed from the collection + # if delete_orphan check is turned on. 
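The secondary_insert/secondary_update/secondary_delete lists assembled below hold association-table rows only; the related objects themselves are flushed by their own mappers. A runnable many-to-many sketch (table and class names invented):

    from sqlalchemy import (Column, ForeignKey, Integer, String, Table,
                            create_engine)
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    # the "secondary" table whose rows _run_crud() below inserts/deletes
    post_keywords = Table(
        'post_keywords', Base.metadata,
        Column('post_id', ForeignKey('post.id'), primary_key=True),
        Column('keyword_id', ForeignKey('keyword.id'), primary_key=True))

    class Post(Base):
        __tablename__ = 'post'
        id = Column(Integer, primary_key=True)
        keywords = relationship('Keyword', secondary=post_keywords)

    class Keyword(Base):
        __tablename__ = 'keyword'
        id = Column(Integer, primary_key=True)
        word = Column(String(50))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = Session(engine)
    p = Post(keywords=[Keyword(word='orm')])
    session.add(p)
    session.flush()   # INSERT INTO post_keywords ... (secondary_insert)
    p.keywords.pop()
    session.flush()   # DELETE FROM post_keywords ... (secondary_delete);
                      # the Keyword row itself stays unless delete-orphan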
+ for state in states: + history = uowcommit.get_attribute_history( + state, + self.key, + attributes.PASSIVE_NO_INITIALIZE) + if history: + for child in history.deleted: + if self.hasparent(child) is False: + uowcommit.register_object( + child, isdelete=True, + operation="delete", prop=self.prop) + for c, m, st_, dct_ in self.mapper.cascade_iterator( + 'delete', + child): + uowcommit.register_object( + st_, isdelete=True) + + def process_deletes(self, uowcommit, states): secondary_delete = [] secondary_insert = [] secondary_update = [] - if delete: - for state in deplist: - history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) - if history: - for child in history.non_added(): - if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"): - continue - associationrow = {} - self._synchronize(state, child, associationrow, False, uowcommit) - secondary_delete.append(associationrow) - self._performed_action(uowcommit, state, child, "manytomany") - else: - for state in deplist: - history = uowcommit.get_attribute_history(state, self.key) - if history: - for child in history.added: - if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"): - continue - associationrow = {} - self._synchronize(state, child, associationrow, False, uowcommit) - self._performed_action(uowcommit, state, child, "manytomany") - secondary_insert.append(associationrow) - for child in history.deleted: - if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"): - continue - associationrow = {} - self._synchronize(state, child, associationrow, False, uowcommit) - self._performed_action(uowcommit, state, child, "manytomany") - secondary_delete.append(associationrow) + processed = self._get_reversed_processed_set(uowcommit) + tmp = set() + for state in states: + # this history should be cached already, as + # we loaded it in preprocess_deletes + history = uowcommit.get_attribute_history( + state, + self.key, + self._passive_delete_flag) + if history: + for child in history.non_added(): + if child is None or \ + (processed is not None and + (state, child) in processed): + continue + associationrow = {} + if not self._synchronize( + state, + child, + associationrow, + False, uowcommit, "delete"): + continue + secondary_delete.append(associationrow) + + tmp.update((c, state) for c in history.non_added()) + + if processed is not None: + processed.update(tmp) + + self._run_crud(uowcommit, secondary_insert, + secondary_update, secondary_delete) + + def process_saves(self, uowcommit, states): + secondary_delete = [] + secondary_insert = [] + secondary_update = [] + + processed = self._get_reversed_processed_set(uowcommit) + tmp = set() + + for state in states: + need_cascade_pks = not self.passive_updates and \ + self._pks_changed(uowcommit, state) + if need_cascade_pks: + passive = attributes.PASSIVE_OFF + else: + passive = attributes.PASSIVE_NO_INITIALIZE + history = uowcommit.get_attribute_history(state, self.key, + passive) + if history: + for child in history.added: + if (processed is not None and + (state, child) in processed): + continue + associationrow = {} + if not self._synchronize(state, + child, + associationrow, + False, uowcommit, "add"): + continue + secondary_insert.append(associationrow) + for child in history.deleted: + if (processed is not None and + (state, child) in processed): + continue + associationrow = {} + if not self._synchronize(state, + child, + associationrow, + False, uowcommit, 
"delete"): + continue + secondary_delete.append(associationrow) + + tmp.update((c, state) + for c in history.added + history.deleted) + + if need_cascade_pks: - if not self.passive_updates and self._pks_changed(uowcommit, state): - if not history: - history = uowcommit.get_attribute_history(state, self.key, passive=False) - for child in history.unchanged: associationrow = {} - sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs) - sync.update(child, self.mapper, associationrow, "old_", self.prop.secondary_synchronize_pairs) + sync.update(state, + self.parent, + associationrow, + "old_", + self.prop.synchronize_pairs) + sync.update(child, + self.mapper, + associationrow, + "old_", + self.prop.secondary_synchronize_pairs) - #self.syncrules.update(associationrow, state, child, "old_") secondary_update.append(associationrow) + if processed is not None: + processed.update(tmp) + + self._run_crud(uowcommit, secondary_insert, + secondary_update, secondary_delete) + + def _run_crud(self, uowcommit, secondary_insert, + secondary_update, secondary_delete): + connection = uowcommit.transaction.connection(self.mapper) + if secondary_delete: + associationrow = secondary_delete[0] statement = self.secondary.delete(sql.and_(*[ - c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow - ])) + c == sql.bindparam(c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ])) result = connection.execute(statement, secondary_delete) - if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_delete): - raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of " - "secondary table rows deleted from table '%s': %d" % - (result.rowcount, self.secondary.description, len(secondary_delete))) + + if result.supports_sane_multi_rowcount() and \ + result.rowcount != len(secondary_delete): + raise exc.StaleDataError( + "DELETE statement on table '%s' expected to delete " + "%d row(s); Only %d were matched." % + (self.secondary.description, len(secondary_delete), + result.rowcount) + ) if secondary_update: + associationrow = secondary_update[0] statement = self.secondary.update(sql.and_(*[ - c == sql.bindparam("old_" + c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow - ])) + c == sql.bindparam("old_" + c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ])) result = connection.execute(statement, secondary_update) - if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update): - raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of " - "secondary table rows updated from table '%s': %d" % - (result.rowcount, self.secondary.description, len(secondary_update))) + + if result.supports_sane_multi_rowcount() and \ + result.rowcount != len(secondary_update): + raise exc.StaleDataError( + "UPDATE statement on table '%s' expected to update " + "%d row(s); Only %d were matched." 
% + (self.secondary.description, len(secondary_update), + result.rowcount) + ) if secondary_insert: statement = self.secondary.insert() connection.execute(statement, secondary_insert) - def preprocess_dependencies(self, task, deplist, uowcommit, delete = False): - if not delete: - for state in deplist: - history = uowcommit.get_attribute_history(state, self.key, passive=True) - if history: - for child in history.deleted: - if self.cascade.delete_orphan and self.hasparent(child) is False: - uowcommit.register_object(child, isdelete=True) - for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object( - attributes.instance_state(c), isdelete=True) + def _synchronize(self, state, child, associationrow, + clearkeys, uowcommit, operation): - def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): - if associationrow is None: - return + # this checks for None if uselist=True self._verify_canload(child) - - sync.populate_dict(state, self.parent, associationrow, - self.prop.synchronize_pairs) + + # but if uselist=False we get here. If child is None, + # no association row can be generated, so return. + if child is None: + return False + + if child is not None and not uowcommit.session._contains_state(child): + if not child.deleted: + util.warn( + "Object of type %s not in session, %s " + "operation along '%s' won't proceed" % + (mapperutil.state_class_str(child), operation, self.prop)) + return False + + sync.populate_dict(state, self.parent, associationrow, + self.prop.synchronize_pairs) sync.populate_dict(child, self.mapper, associationrow, - self.prop.secondary_synchronize_pairs) + self.prop.secondary_synchronize_pairs) + + return True def _pks_changed(self, uowcommit, state): - return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs) + return sync.source_modified( + uowcommit, + state, + self.parent, + self.prop.synchronize_pairs) -class MapperStub(object): - """Represent a many-to-many dependency within a flush - context. - - The UOWTransaction corresponds dependencies to mappers. - MapperStub takes the place of the "association table" - so that a depedendency can be corresponded to it. - - """ - - def __init__(self, parent, mapper, key): - self.mapper = mapper - self.base_mapper = self - self.class_ = mapper.class_ - self._inheriting_mappers = [] - - def polymorphic_iterator(self): - return iter((self,)) - - def _register_dependencies(self, uowcommit): - pass - - def _register_procesors(self, uowcommit): - pass - - def _save_obj(self, *args, **kwargs): - pass - - def _delete_obj(self, *args, **kwargs): - pass - - def primary_mapper(self): - return self +_direction_to_processor = { + ONETOMANY: OneToManyDP, + MANYTOONE: ManyToOneDP, + MANYTOMANY: ManyToManyDP, +} diff --git a/sqlalchemy/orm/dynamic.py b/sqlalchemy/orm/dynamic.py index d796040..9f99740 100644 --- a/sqlalchemy/orm/dynamic.py +++ b/sqlalchemy/orm/dynamic.py @@ -1,5 +1,6 @@ -# dynamic.py -# Copyright (C) the SQLAlchemy authors and contributors +# orm/dynamic.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -11,42 +12,47 @@ basic add/delete mutation. 
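With dynamic.py beginning here, a minimal runnable sketch of what the lazy='dynamic' strategy gives the application (model names invented): the attribute is an AppenderQuery, filterable and appendable without loading the whole collection.

    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, relationship

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'
        id = Column(Integer, primary_key=True)
        # 'dynamic' is collections-only; see the InvalidRequestError the
        # patch adds to DynaLoader.init_class_attribute for uselist=False
        addresses = relationship('Address', lazy='dynamic')

    class Address(Base):
        __tablename__ = 'address'
        id = Column(Integer, primary_key=True)
        user_id = Column(Integer, ForeignKey('user.id'))
        email = Column(String(50))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = Session(engine)
    u = User(addresses=[Address(email='a@example.com')])
    session.add(u)
    session.flush()
    print(u.addresses.filter(Address.email.like('%example%')).count())
    u.addresses.append(Address(email='b@example.com'))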
""" -from sqlalchemy import log, util -from sqlalchemy import exc as sa_exc -from sqlalchemy.orm import exc as sa_exc -from sqlalchemy.sql import operators -from sqlalchemy.orm import ( - attributes, object_session, util as mapperutil, strategies, object_mapper - ) -from sqlalchemy.orm.query import Query -from sqlalchemy.orm.util import _state_has_identity, has_identity -from sqlalchemy.orm import attributes, collections +from .. import log, util, exc +from ..sql import operators +from . import ( + attributes, object_session, util as orm_util, strategies, + object_mapper, exc as orm_exc, properties +) +from .query import Query + +@log.class_logger +@properties.RelationshipProperty.strategy_for(lazy="dynamic") class DynaLoader(strategies.AbstractRelationshipLoader): def init_class_attribute(self, mapper): self.is_class_level = True - - strategies._register_attribute(self, + if not self.uselist: + raise exc.InvalidRequestError( + "On relationship %s, 'dynamic' loaders cannot be used with " + "many-to-one/one-to-one relationships and/or " + "uselist=False." % self.parent_property) + strategies._register_attribute( + self.parent_property, mapper, useobject=True, impl_class=DynamicAttributeImpl, target_mapper=self.parent_property.mapper, order_by=self.parent_property.order_by, - query_class=self.parent_property.query_class + query_class=self.parent_property.query_class, ) - def create_row_processor(self, selectcontext, path, mapper, row, adapter): - return (None, None) - -log.class_logger(DynaLoader) class DynamicAttributeImpl(attributes.AttributeImpl): uses_objects = True accepts_scalar_loader = False + supports_population = False + collection = False def __init__(self, class_, key, typecallable, - target_mapper, order_by, query_class=None, **kwargs): - super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs) + dispatch, + target_mapper, order_by, query_class=None, **kw): + super(DynamicAttributeImpl, self).\ + __init__(class_, key, typecallable, dispatch, **kw) self.target_mapper = target_mapper self.order_by = order_by if not query_class: @@ -56,178 +62,204 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: self.query_class = mixin_user_query(query_class) - def get(self, state, dict_, passive=False): - if passive: - return self._get_collection_history(state, passive=True).added_items + def get(self, state, dict_, passive=attributes.PASSIVE_OFF): + if not passive & attributes.SQL_OK: + return self._get_collection_history( + state, attributes.PASSIVE_NO_INITIALIZE).added_items else: return self.query_class(self, state) - def get_collection(self, state, dict_, user_data=None, passive=True): - if passive: - return self._get_collection_history(state, passive=passive).added_items + def get_collection(self, state, dict_, user_data=None, + passive=attributes.PASSIVE_NO_INITIALIZE): + if not passive & attributes.SQL_OK: + return self._get_collection_history(state, + passive).added_items else: - history = self._get_collection_history(state, passive=passive) - return history.added_items + history.unchanged_items + history = self._get_collection_history(state, passive) + return history.added_plus_unchanged - def fire_append_event(self, state, dict_, value, initiator): - collection_history = self._modified_event(state, dict_) - collection_history.added_items.append(value) + @util.memoized_property + def _append_token(self): + return attributes.Event(self, attributes.OP_APPEND) - for ext in self.extensions: - ext.append(state, value, initiator or self) + 
@util.memoized_property + def _remove_token(self): + return attributes.Event(self, attributes.OP_REMOVE) + + def fire_append_event(self, state, dict_, value, initiator, + collection_history=None): + if collection_history is None: + collection_history = self._modified_event(state, dict_) + + collection_history.add_added(value) + + for fn in self.dispatch.append: + value = fn(state, value, initiator or self._append_token) if self.trackparent and value is not None: - self.sethasparent(attributes.instance_state(value), True) + self.sethasparent(attributes.instance_state(value), state, True) - def fire_remove_event(self, state, dict_, value, initiator): - collection_history = self._modified_event(state, dict_) - collection_history.deleted_items.append(value) + def fire_remove_event(self, state, dict_, value, initiator, + collection_history=None): + if collection_history is None: + collection_history = self._modified_event(state, dict_) + + collection_history.add_removed(value) if self.trackparent and value is not None: - self.sethasparent(attributes.instance_state(value), False) + self.sethasparent(attributes.instance_state(value), state, False) - for ext in self.extensions: - ext.remove(state, value, initiator or self) + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) def _modified_event(self, state, dict_): if self.key not in state.committed_state: state.committed_state[self.key] = CollectionHistory(self, state) - state.modified_event(dict_, - self, - False, - attributes.NEVER_SET, - passive=attributes.PASSIVE_NO_INITIALIZE) + state._modified_event(dict_, + self, + attributes.NEVER_SET) - # this is a hack to allow the _base.ComparableEntity fixture + # this is a hack to allow the fixtures.ComparableEntity fixture # to work dict_[self.key] = True return state.committed_state[self.key] - def set(self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF): - if initiator is self: + def set(self, state, dict_, value, initiator=None, + passive=attributes.PASSIVE_OFF, + check_old=None, pop=False, _adapt=True): + if initiator and initiator.parent_token is self.parent_token: return - self._set_iterable(state, dict_, value) + if pop and value is None: + return - def _set_iterable(self, state, dict_, iterable, adapter=None): + iterable = value + new_values = list(iterable) + if state.has_identity: + old_collection = util.IdentitySet(self.get(state, dict_)) collection_history = self._modified_event(state, dict_) - new_values = list(iterable) - - if _state_has_identity(state): - old_collection = list(self.get(state, dict_)) + if not state.has_identity: + old_collection = collection_history.added_items else: - old_collection = [] + old_collection = old_collection.union( + collection_history.added_items) - collections.bulk_replace(new_values, DynCollectionAdapter(self, state, old_collection), DynCollectionAdapter(self, state, new_values)) + idset = util.IdentitySet + constants = old_collection.intersection(new_values) + additions = idset(new_values).difference(constants) + removals = old_collection.difference(constants) + + for member in new_values: + if member in additions: + self.fire_append_event(state, dict_, member, None, + collection_history=collection_history) + + for member in removals: + self.fire_remove_event(state, dict_, member, None, + collection_history=collection_history) def delete(self, *args, **kwargs): raise NotImplementedError() - def get_history(self, state, dict_, passive=False): - c = self._get_collection_history(state, passive) - return 
attributes.History(c.added_items, c.unchanged_items, c.deleted_items) + def set_committed_value(self, state, dict_, value): + raise NotImplementedError("Dynamic attributes don't support " + "collection population.") - def _get_collection_history(self, state, passive=False): + def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF): + c = self._get_collection_history(state, passive) + return c.as_history() + + def get_all_pending(self, state, dict_, + passive=attributes.PASSIVE_NO_INITIALIZE): + c = self._get_collection_history( + state, passive) + return [ + (attributes.instance_state(x), x) + for x in + c.all_items + ] + + def _get_collection_history(self, state, passive=attributes.PASSIVE_OFF): if self.key in state.committed_state: c = state.committed_state[self.key] else: c = CollectionHistory(self, state) - if not passive: + if state.has_identity and (passive & attributes.INIT_OK): return CollectionHistory(self, state, apply_to=c) else: return c - def append(self, state, dict_, value, initiator, passive=False): + def append(self, state, dict_, value, initiator, + passive=attributes.PASSIVE_OFF): if initiator is not self: self.fire_append_event(state, dict_, value, initiator) - def remove(self, state, dict_, value, initiator, passive=False): + def remove(self, state, dict_, value, initiator, + passive=attributes.PASSIVE_OFF): if initiator is not self: self.fire_remove_event(state, dict_, value, initiator) -class DynCollectionAdapter(object): - """the dynamic analogue to orm.collections.CollectionAdapter""" + def pop(self, state, dict_, value, initiator, + passive=attributes.PASSIVE_OFF): + self.remove(state, dict_, value, initiator, passive=passive) - def __init__(self, attr, owner_state, data): - self.attr = attr - self.state = owner_state - self.data = data - - def __iter__(self): - return iter(self.data) - - def append_with_event(self, item, initiator=None): - self.attr.append(self.state, self.state.dict, item, initiator) - - def remove_with_event(self, item, initiator=None): - self.attr.remove(self.state, self.state.dict, item, initiator) - - def append_without_event(self, item): - pass - - def remove_without_event(self, item): - pass class AppenderMixin(object): query_class = None def __init__(self, attr, state): - Query.__init__(self, attr.target_mapper, None) + super(AppenderMixin, self).__init__(attr.target_mapper, None) self.instance = instance = state.obj() self.attr = attr mapper = object_mapper(instance) - prop = mapper.get_property(self.attr.key, resolve_synonyms=True) - self._criterion = prop.compare( - operators.eq, - instance, - value_is_parent=True, - alias_secondary=False) + prop = mapper._props[self.attr.key] + self._criterion = prop._with_parent( + instance, + alias_secondary=False) if self.attr.order_by: self._order_by = self.attr.order_by - def __session(self): + def session(self): sess = object_session(self.instance) - if sess is not None and self.autoflush and sess.autoflush and self.instance in sess: + if sess is not None and self.autoflush and sess.autoflush \ + and self.instance in sess: sess.flush() - if not has_identity(self.instance): + if not orm_util.has_identity(self.instance): return None else: return sess - - def session(self): - return self.__session() - session = property(session, lambda s, x:None) + session = property(session, lambda s, x: None) def __iter__(self): - sess = self.__session() + sess = self.session if sess is None: return iter(self.attr._get_collection_history( attributes.instance_state(self.instance), - 
passive=True).added_items) + attributes.PASSIVE_NO_INITIALIZE).added_items) else: return iter(self._clone(sess)) def __getitem__(self, index): - sess = self.__session() + sess = self.session if sess is None: return self.attr._get_collection_history( attributes.instance_state(self.instance), - passive=True).added_items.__getitem__(index) + attributes.PASSIVE_NO_INITIALIZE).indexed(index) else: return self._clone(sess).__getitem__(index) def count(self): - sess = self.__session() + sess = self.session if sess is None: return len(self.attr._get_collection_history( attributes.instance_state(self.instance), - passive=True).added_items) + attributes.PASSIVE_NO_INITIALIZE).added_items) else: return self._clone(sess).count() @@ -243,26 +275,32 @@ class AppenderMixin(object): "Parent instance %s is not bound to a Session, and no " "contextual session is established; lazy load operation " "of attribute '%s' cannot proceed" % ( - mapperutil.instance_str(instance), self.attr.key)) + orm_util.instance_str(instance), self.attr.key)) if self.query_class: query = self.query_class(self.attr.target_mapper, session=sess) else: query = sess.query(self.attr.target_mapper) - + query._criterion = self._criterion query._order_by = self._order_by - + return query + def extend(self, iterator): + for item in iterator: + self.attr.append( + attributes.instance_state(self.instance), + attributes.instance_dict(self.instance), item, None) + def append(self, item): self.attr.append( - attributes.instance_state(self.instance), + attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None) def remove(self, item): self.attr.remove( - attributes.instance_state(self.instance), + attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None) @@ -275,19 +313,55 @@ def mixin_user_query(cls): name = 'Appender' + cls.__name__ return type(name, (AppenderMixin, cls), {'query_class': cls}) + class CollectionHistory(object): """Overrides AttributeHistory to receive append/remove events directly.""" def __init__(self, attr, state, apply_to=None): if apply_to: - deleted = util.IdentitySet(apply_to.deleted_items) - added = apply_to.added_items coll = AppenderQuery(attr, state).autoflush(False) - self.unchanged_items = [o for o in util.IdentitySet(coll) if o not in deleted] + self.unchanged_items = util.OrderedIdentitySet(coll) self.added_items = apply_to.added_items self.deleted_items = apply_to.deleted_items + self._reconcile_collection = True else: - self.deleted_items = [] - self.added_items = [] - self.unchanged_items = [] + self.deleted_items = util.OrderedIdentitySet() + self.added_items = util.OrderedIdentitySet() + self.unchanged_items = util.OrderedIdentitySet() + self._reconcile_collection = False + @property + def added_plus_unchanged(self): + return list(self.added_items.union(self.unchanged_items)) + + @property + def all_items(self): + return list(self.added_items.union( + self.unchanged_items).union(self.deleted_items)) + + def as_history(self): + if self._reconcile_collection: + added = self.added_items.difference(self.unchanged_items) + deleted = self.deleted_items.intersection(self.unchanged_items) + unchanged = self.unchanged_items.difference(deleted) + else: + added, unchanged, deleted = self.added_items,\ + self.unchanged_items,\ + self.deleted_items + return attributes.History( + list(added), + list(unchanged), + list(deleted), + ) + + def indexed(self, index): + return list(self.added_items)[index] + + def add_added(self, value): + 
self.added_items.add(value) + + def add_removed(self, value): + if value in self.added_items: + self.added_items.remove(value) + else: + self.deleted_items.add(value) diff --git a/sqlalchemy/orm/evaluator.py b/sqlalchemy/orm/evaluator.py index 3ee7078..95a9e9b 100644 --- a/sqlalchemy/orm/evaluator.py +++ b/sqlalchemy/orm/evaluator.py @@ -1,17 +1,21 @@ +# orm/evaluator.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + import operator -from sqlalchemy.sql import operators, functions -from sqlalchemy.sql import expression as sql +from ..sql import operators class UnevaluatableError(Exception): pass _straight_ops = set(getattr(operators, op) - for op in ('add', 'mul', 'sub', - # Py2K - 'div', - # end Py2K - 'mod', 'truediv', + for op in ('add', 'mul', 'sub', + 'div', + 'mod', 'truediv', 'lt', 'le', 'ne', 'gt', 'ge', 'eq')) @@ -20,11 +24,16 @@ _notimplemented_ops = set(getattr(operators, op) 'notilike_op', 'between_op', 'in_op', 'notin_op', 'endswith_op', 'concat_op')) + class EvaluatorCompiler(object): + def __init__(self, target_cls=None): + self.target_cls = target_cls + def process(self, clause): meth = getattr(self, "visit_%s" % clause.__visit_name__, None) if not meth: - raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__) + raise UnevaluatableError( + "Cannot evaluate %s" % type(clause).__name__) return meth(clause) def visit_grouping(self, clause): @@ -33,16 +42,30 @@ class EvaluatorCompiler(object): def visit_null(self, clause): return lambda obj: None + def visit_false(self, clause): + return lambda obj: False + + def visit_true(self, clause): + return lambda obj: True + def visit_column(self, clause): if 'parentmapper' in clause._annotations: - key = clause._annotations['parentmapper']._get_col_to_prop(clause).key + parentmapper = clause._annotations['parentmapper'] + if self.target_cls and not issubclass( + self.target_cls, parentmapper.class_): + raise UnevaluatableError( + "Can't evaluate criteria against alternate class %s" % + parentmapper.class_ + ) + key = parentmapper._columntoproperty[clause].key else: key = clause.key + get_corresponding_attr = operator.attrgetter(key) return lambda obj: get_corresponding_attr(obj) def visit_clauselist(self, clause): - evaluators = map(self.process, clause.clauses) + evaluators = list(map(self.process, clause.clauses)) if clause.operator is operators.or_: def evaluate(obj): has_null = False @@ -64,12 +87,15 @@ class EvaluatorCompiler(object): return False return True else: - raise UnevaluatableError("Cannot evaluate clauselist with operator %s" % clause.operator) + raise UnevaluatableError( + "Cannot evaluate clauselist with operator %s" % + clause.operator) return evaluate def visit_binary(self, clause): - eval_left,eval_right = map(self.process, [clause.left, clause.right]) + eval_left, eval_right = list(map(self.process, + [clause.left, clause.right])) operator = clause.operator if operator is operators.is_: def evaluate(obj): @@ -85,7 +111,9 @@ class EvaluatorCompiler(object): return None return operator(eval_left(obj), eval_right(obj)) else: - raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator)) + raise UnevaluatableError( + "Cannot evaluate %s with operator %s" % + (type(clause).__name__, clause.operator)) return evaluate def visit_unary(self, clause): @@ -97,8 +125,13 @@ class EvaluatorCompiler(object): 
return None return not value return evaluate - raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator)) + raise UnevaluatableError( + "Cannot evaluate %s with operator %s" % + (type(clause).__name__, clause.operator)) def visit_bindparam(self, clause): - val = clause.value + if clause.callable: + val = clause.callable() + else: + val = clause.value return lambda obj: val diff --git a/sqlalchemy/orm/exc.py b/sqlalchemy/orm/exc.py index 431acc1..c13bb67 100644 --- a/sqlalchemy/orm/exc.py +++ b/sqlalchemy/orm/exc.py @@ -1,42 +1,79 @@ -# exc.py - ORM exceptions -# Copyright (C) the SQLAlchemy authors and contributors +# orm/exc.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """SQLAlchemy ORM exceptions.""" - -import sqlalchemy as sa - +from .. import exc as sa_exc, util NO_STATE = (AttributeError, KeyError) """Exception types that may be raised by instrumentation implementations.""" -class ConcurrentModificationError(sa.exc.SQLAlchemyError): - """Rows have been modified outside of the unit of work.""" + +class StaleDataError(sa_exc.SQLAlchemyError): + """An operation encountered database state that is unaccounted for. + + Conditions which cause this to happen include: + + * A flush may have attempted to update or delete rows + and an unexpected number of rows were matched during + the UPDATE or DELETE statement. Note that when + version_id_col is used, rows in UPDATE or DELETE statements + are also matched against the current known version + identifier. + + * A mapped object with version_id_col was refreshed, + and the version number coming back from the database does + not match that of the object itself. + + * An object is detached from its parent object; however, + the object was previously attached to a different parent + identity which was garbage collected, and a decision + cannot be made if the new parent was really the most + recent "parent". + + .. versionadded:: 0.7.4 + + """ + +ConcurrentModificationError = StaleDataError -class FlushError(sa.exc.SQLAlchemyError): +class FlushError(sa_exc.SQLAlchemyError): """An invalid condition was detected during flush().""" -class UnmappedError(sa.exc.InvalidRequestError): - """TODO""" +class UnmappedError(sa_exc.InvalidRequestError): + """Base for exceptions that involve expected mappings not present.""" + + +class ObjectDereferencedError(sa_exc.SQLAlchemyError): + """An operation cannot complete due to an object being garbage + collected. + + """ + + +class DetachedInstanceError(sa_exc.SQLAlchemyError): + """An attempt to access unloaded attributes on a + mapped instance that is detached.""" + -class DetachedInstanceError(sa.exc.SQLAlchemyError): - """An attempt to access unloaded attributes on a mapped instance that is detached.""" -
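For orientation, a minimal sketch of how the reworked exception hierarchy above is consumed; the session object and the mapped model behind it are assumed for illustration and are not part of this diff:

from sqlalchemy.orm import exc as orm_exc

# A flush against a mapper configured with version_id_col can detect
# rows changed or deleted by another transaction; per the docstring
# above, that now surfaces as StaleDataError (added in 0.7.4).
try:
    session.commit()          # "session" is an assumed Session
except orm_exc.StaleDataError:
    session.rollback()

# Legacy code that caught ConcurrentModificationError keeps working,
# since that name is now simply an alias:
assert orm_exc.ConcurrentModificationError is orm_exc.StaleDataError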
class UnmappedInstanceError(UnmappedError): """A mapping operation was requested for an unknown instance.""" - def __init__(self, obj, msg=None): + @util.dependencies("sqlalchemy.orm.base") + def __init__(self, base, obj, msg=None): if not msg: try: - mapper = sa.orm.class_mapper(type(obj)) + base.class_mapper(type(obj)) name = _safe_cls_name(type(obj)) msg = ("Class %r is mapped, but this instance lacks " - "instrumentation. This occurs when the instance is created " - "before sqlalchemy.orm.mapper(%s) was called." % (name, name)) + "instrumentation. This occurs when the instance " + "is created before sqlalchemy.orm.mapper(%s) " + "was called." % (name, name)) except UnmappedClassError: msg = _default_unmapped(type(obj)) if isinstance(obj, type): @@ -45,6 +82,9 @@ class UnmappedInstanceError(UnmappedError): 'required?' % _safe_cls_name(obj)) UnmappedError.__init__(self, msg) + def __reduce__(self): + return self.__class__, (None, self.args[0]) + class UnmappedClassError(UnmappedError): """A mapping operation was requested for an unknown class.""" @@ -54,28 +94,53 @@ class UnmappedClassError(UnmappedError): msg = _default_unmapped(cls) UnmappedError.__init__(self, msg) - -class ObjectDeletedError(sa.exc.InvalidRequestError): - """An refresh() operation failed to re-retrieve an object's row.""" + def __reduce__(self): + return self.__class__, (None, self.args[0]) -class UnmappedColumnError(sa.exc.InvalidRequestError): +class ObjectDeletedError(sa_exc.InvalidRequestError): + """A refresh operation failed to retrieve the database + row corresponding to an object's known primary key identity. + + A refresh operation proceeds when an expired attribute is + accessed on an object, or when :meth:`.Query.get` is + used to retrieve an object which is, upon retrieval, detected + as expired. A SELECT is emitted for the target row + based on primary key; if no row is returned, this + exception is raised. + + The true meaning of this exception is simply that + no row exists for the primary key identifier associated + with a persistent object. The row may have been + deleted, or in some cases the primary key updated + to a new value, outside of the ORM's management of the target + object. + + """ + @util.dependencies("sqlalchemy.orm.base") + def __init__(self, base, state, msg=None): + if not msg: + msg = "Instance '%s' has been deleted, or its "\ + "row is otherwise not present." % base.state_str(state) + + sa_exc.InvalidRequestError.__init__(self, msg) + + def __reduce__(self): + return self.__class__, (None, self.args[0]) + + +class UnmappedColumnError(sa_exc.InvalidRequestError): """Mapping operation was requested on an unknown column.""" -class NoResultFound(sa.exc.InvalidRequestError): +class NoResultFound(sa_exc.InvalidRequestError): """A database result was required but none was found.""" -class MultipleResultsFound(sa.exc.InvalidRequestError): +class MultipleResultsFound(sa_exc.InvalidRequestError): """A single database result was required but more than one was found.""" -# Legacy compat until 0.6. -sa.exc.ConcurrentModificationError = ConcurrentModificationError -sa.exc.FlushError = FlushError -sa.exc.UnmappedColumnError - def _safe_cls_name(cls): try: cls_name = '.'.join((cls.__module__, cls.__name__)) @@ -85,9 +150,11 @@ def _safe_cls_name(cls): cls_name = repr(cls) return cls_name -def _default_unmapped(cls): + +@util.dependencies("sqlalchemy.orm.base") +def _default_unmapped(base, cls): try: - mappers = sa.orm.attributes.manager_of_class(cls).mappers + mappers = base.manager_of_class(cls).mappers except NO_STATE: mappers = {} except TypeError: diff --git a/sqlalchemy/orm/identity.py b/sqlalchemy/orm/identity.py index 4650b06..ca87fa2 100644 --- a/sqlalchemy/orm/identity.py +++ b/sqlalchemy/orm/identity.py @@ -1,67 +1,66 @@ -# identity.py -# Copyright (C) the SQLAlchemy authors and contributors +# orm/identity.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php import weakref +from . 
import attributes +from .. import util +from .. import exc as sa_exc +from . import util as orm_util -from sqlalchemy import util as base_util -from sqlalchemy.orm import attributes - - -class IdentityMap(dict): +class IdentityMap(object): def __init__(self): - self._mutable_attrs = set() + self._dict = {} self._modified = set() self._wr = weakref.ref(self) + def keys(self): + return self._dict.keys() + def replace(self, state): raise NotImplementedError() - + def add(self, state): raise NotImplementedError() - - def remove(self, state): - raise NotImplementedError() - + + def _add_unpresent(self, state, key): + """optional inlined form of add() which can assume item isn't present + in the map""" + self.add(state) + def update(self, dict): raise NotImplementedError("IdentityMap uses add() to insert data") - + def clear(self): raise NotImplementedError("IdentityMap uses remove() to remove data") - + def _manage_incoming_state(self, state): state._instance_dict = self._wr - + if state.modified: - self._modified.add(state) - if state.manager.mutable_attributes: - self._mutable_attrs.add(state) - + self._modified.add(state) + def _manage_removed_state(self, state): del state._instance_dict - self._mutable_attrs.discard(state) - self._modified.discard(state) + if state.modified: + self._modified.discard(state) def _dirty_states(self): - return self._modified.union(s for s in self._mutable_attrs.copy() - if s.modified) + return self._modified def check_modified(self): - """return True if any InstanceStates present have been marked as 'modified'.""" - - if self._modified: - return True - else: - for state in self._mutable_attrs.copy(): - if state.modified: - return True - return False - + """return True if any InstanceStates present have been marked + as 'modified'. 
+ + """ + return bool(self._modified) + def has_key(self, key): return key in self - + def popitem(self): raise NotImplementedError("IdentityMap uses remove() to remove data") @@ -71,6 +70,9 @@ class IdentityMap(dict): def setdefault(self, key, default=None): raise NotImplementedError("IdentityMap uses add() to insert data") + def __len__(self): + return len(self._dict) + def copy(self): raise NotImplementedError() @@ -79,164 +81,233 @@ class IdentityMap(dict): def __delitem__(self, key): raise NotImplementedError("IdentityMap uses remove() to remove data") - + + class WeakInstanceDict(IdentityMap): def __getitem__(self, key): - state = dict.__getitem__(self, key) + state = self._dict[key] o = state.obj() if o is None: - o = state._is_really_none() - if o is None: - raise KeyError, key + raise KeyError(key) return o def __contains__(self, key): try: - if dict.__contains__(self, key): - state = dict.__getitem__(self, key) + if key in self._dict: + state = self._dict[key] o = state.obj() - if o is None: - o = state._is_really_none() else: return False except KeyError: return False else: return o is not None - + def contains_state(self, state): - return dict.get(self, state.key) is state - + return state.key in self._dict and self._dict[state.key] is state + def replace(self, state): - if dict.__contains__(self, state.key): - existing = dict.__getitem__(self, state.key) + if state.key in self._dict: + existing = self._dict[state.key] if existing is not state: self._manage_removed_state(existing) else: return - - dict.__setitem__(self, state.key, state) + + self._dict[state.key] = state self._manage_incoming_state(state) - + def add(self, state): - if state.key in self: - if dict.__getitem__(self, state.key) is not state: - raise AssertionError("A conflicting state is already " - "present in the identity map for key %r" - % (state.key, )) - else: - dict.__setitem__(self, state.key, state) - self._manage_incoming_state(state) - - def remove_key(self, key): - state = dict.__getitem__(self, key) - self.remove(state) - - def remove(self, state): - if dict.pop(self, state.key) is not state: - raise AssertionError("State %s is not present in this identity map" % state) - self._manage_removed_state(state) - - def discard(self, state): - if self.contains_state(state): - dict.__delitem__(self, state.key) - self._manage_removed_state(state) - + key = state.key + # inline of self.__contains__ + if key in self._dict: + try: + existing_state = self._dict[key] + if existing_state is not state: + o = existing_state.obj() + if o is not None: + raise sa_exc.InvalidRequestError( + "Can't attach instance " + "%s; another instance with key %s is already " + "present in this session." 
% ( + orm_util.state_str(state), state.key)) + else: + return False + except KeyError: + pass + self._dict[key] = state + self._manage_incoming_state(state) + return True + + def _add_unpresent(self, state, key): + # inlined form of add() called by loading.py + self._dict[key] = state + state._instance_dict = self._wr + def get(self, key, default=None): - state = dict.get(self, key, default) - if state is default: + if key not in self._dict: return default + state = self._dict[key] o = state.obj() - if o is None: - o = state._is_really_none() if o is None: return default return o - - # Py2K - def items(self): - return list(self.iteritems()) - def iteritems(self): - for state in dict.itervalues(self): - # end Py2K - # Py3K - #def items(self): - # for state in dict.values(self): + def items(self): + values = self.all_states() + result = [] + for state in values: value = state.obj() if value is not None: - yield state.key, value + result.append((state.key, value)) + return result - # Py2K def values(self): - return list(self.itervalues()) + values = self.all_states() + result = [] + for state in values: + value = state.obj() + if value is not None: + result.append(value) - def itervalues(self): - for state in dict.itervalues(self): - # end Py2K - # Py3K - #def values(self): - # for state in dict.values(self): - instance = state.obj() - if instance is not None: - yield instance + return result + + def __iter__(self): + return iter(self.keys()) + + if util.py2k: + + def iteritems(self): + return iter(self.items()) + + def itervalues(self): + return iter(self.values()) def all_states(self): - # Py3K - # return list(dict.values(self)) - - # Py2K - return dict.values(self) - # end Py2K - + if util.py2k: + return self._dict.values() + else: + return list(self._dict.values()) + + def _fast_discard(self, state): + self._dict.pop(state.key, None) + + def discard(self, state): + st = self._dict.pop(state.key, None) + if st: + assert st is state + self._manage_removed_state(state) + + def safe_discard(self, state): + if state.key in self._dict: + st = self._dict[state.key] + if st is state: + self._dict.pop(state.key, None) + self._manage_removed_state(state) + def prune(self): return 0 - + +
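To make the weak-referencing behavior implemented above concrete, a small sketch; Widget is an assumed mapped class and session an assumed Session, neither defined in this diff:

import gc

obj = session.query(Widget).get(1)   # loads and registers in the map
assert obj in session                # tracked by the weak identity map

del obj                              # drop the only strong reference
gc.collect()

# the WeakInstanceDict entry's weakref is now dead, so the identity
# lookup misses and a fresh SELECT is emitted:
obj = session.query(Widget).get(1)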
class StrongInstanceDict(IdentityMap): + """A 'strong-referencing' version of the identity map. + + .. deprecated:: 1.1 + The strong reference identity map is legacy. See the + recipe at :ref:`session_referencing_behavior` for + an event-based approach to maintaining strong identity + references. + + + """ + + if util.py2k: + def itervalues(self): + return self._dict.itervalues() + + def iteritems(self): + return self._dict.iteritems() + + def __iter__(self): + return iter(self._dict) + + def __getitem__(self, key): + return self._dict[key] + + def __contains__(self, key): + return key in self._dict + + def get(self, key, default=None): + return self._dict.get(key, default) + + def values(self): + return self._dict.values() + + def items(self): + return self._dict.items() + def all_states(self): - return [attributes.instance_state(o) for o in self.itervalues()] - + return [attributes.instance_state(o) for o in self.values()] + def contains_state(self, state): - return state.key in self and attributes.instance_state(self[state.key]) is state - + return ( + state.key in self and + attributes.instance_state(self[state.key]) is state) + def replace(self, state): - if dict.__contains__(self, state.key): - existing = dict.__getitem__(self, state.key) + if state.key in self._dict: + existing = self._dict[state.key] existing = attributes.instance_state(existing) if existing is not state: self._manage_removed_state(existing) else: return - dict.__setitem__(self, state.key, state.obj()) + self._dict[state.key] = state.obj() self._manage_incoming_state(state) def add(self, state): if state.key in self: - if attributes.instance_state(dict.__getitem__(self, state.key)) is not state: - raise AssertionError("A conflicting state is already present in the identity map for key %r" % (state.key, )) + if attributes.instance_state(self._dict[state.key]) is not state: + raise sa_exc.InvalidRequestError( + "Can't attach instance " + "%s; another instance with key %s is already " + "present in this session." % ( + orm_util.state_str(state), state.key)) + return False else: - dict.__setitem__(self, state.key, state.obj()) + self._dict[state.key] = state.obj() self._manage_incoming_state(state) - - def remove(self, state): - if attributes.instance_state(dict.pop(self, state.key)) is not state: - raise AssertionError("State %s is not present in this identity map" % state) - self._manage_removed_state(state) - + return True + + def _add_unpresent(self, state, key): + # inlined form of add() called by loading.py + self._dict[key] = state.obj() + state._instance_dict = self._wr + + def _fast_discard(self, state): + self._dict.pop(state.key, None) + def discard(self, state): - if self.contains_state(state): - dict.__delitem__(self, state.key) + obj = self._dict.pop(state.key, None) + if obj is not None: self._manage_removed_state(state) - - def remove_key(self, key): - state = attributes.instance_state(dict.__getitem__(self, key)) - self.remove(state) + st = attributes.instance_state(obj) + assert st is state + + def safe_discard(self, state): + if state.key in self._dict: + obj = self._dict[state.key] + st = attributes.instance_state(obj) + if st is state: + self._dict.pop(state.key, None) + self._manage_removed_state(state) def prune(self): """prune unreferenced, non-dirty states.""" - + ref_count = len(self) dirty = [s.obj() for s in self.all_states() if s.modified] @@ -244,8 +315,7 @@ class StrongInstanceDict(IdentityMap): keepers = weakref.WeakValueDictionary() keepers.update(self) - dict.clear(self) - dict.update(self, keepers) + self._dict.clear() + self._dict.update(keepers) self.modified = bool(dirty) return ref_count - len(self) -
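For contrast, a sketch of the strong-referencing variant and its prune() hook, using the legacy weak_identity_map=False flag; engine and Widget are assumed:

from sqlalchemy.orm import sessionmaker

# weak_identity_map=False installs StrongInstanceDict (deprecated in
# 1.1, as the docstring above notes)
Session = sessionmaker(bind=engine, weak_identity_map=False)
session = Session()

widgets = session.query(Widget).all()
del widgets      # the instances stay alive, referenced by the map

# prune() discards clean states with no other referents and returns
# the number removed
removed = session.prune()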
diff --git a/sqlalchemy/orm/interfaces.py b/sqlalchemy/orm/interfaces.py index 7fbb086..fbe8f50 100644 --- a/sqlalchemy/orm/interfaces.py +++ b/sqlalchemy/orm/interfaces.py @@ -1,413 +1,113 @@ -# interfaces.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# orm/interfaces.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """ -Semi-private module containing various base classes used throughout the ORM. +Contains various base classes used throughout the ORM. -Defines the extension classes :class:`MapperExtension`, -:class:`SessionExtension`, and :class:`AttributeExtension` as -well as other user-subclassable extension objects. +Defines some key base classes prominent within the internals, +as well as the now-deprecated ORM extension classes. + +Other than the deprecated extensions, this module and the +classes within are mostly private, though some attributes +are exposed when inspecting mappings. """ -from itertools import chain +from __future__ import absolute_import -import sqlalchemy.exceptions as sa_exc -from sqlalchemy import log, util -from sqlalchemy.sql import expression +from .. import util +from ..sql import operators +from .base import (ONETOMANY, MANYTOONE, MANYTOMANY, + EXT_CONTINUE, EXT_STOP, NOT_EXTENSION) +from .base import (InspectionAttr, + InspectionAttrInfo, _MappedAttribute) +import collections +from .. import inspect +from . import path_registry -class_mapper = None -collections = None +# imported later +MapperExtension = SessionExtension = AttributeExtension = None __all__ = ( 'AttributeExtension', 'EXT_CONTINUE', 'EXT_STOP', - 'ExtensionOption', - 'InstrumentationManager', + 'ONETOMANY', + 'MANYTOMANY', + 'MANYTOONE', + 'NOT_EXTENSION', 'LoaderStrategy', 'MapperExtension', 'MapperOption', 'MapperProperty', 'PropComparator', - 'PropertyOption', 'SessionExtension', - 'StrategizedOption', 'StrategizedProperty', - 'build_path', +) + + +class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): + """Represent a particular class attribute mapped by :class:`.Mapper`. + + The most common occurrences of :class:`.MapperProperty` are the + mapped :class:`.Column`, which is represented in a mapping as + an instance of :class:`.ColumnProperty`, + and a reference to another class produced by :func:`.relationship`, + represented in the mapping as an instance of + :class:`.RelationshipProperty`. + + """ + + __slots__ = ( + '_configure_started', '_configure_finished', 'parent', 'key', + 'info' ) -EXT_CONTINUE = util.symbol('EXT_CONTINUE') -EXT_STOP = util.symbol('EXT_STOP') + cascade = frozenset() + """The set of 'cascade' attribute names. -ONETOMANY = util.symbol('ONETOMANY') -MANYTOONE = util.symbol('MANYTOONE') -MANYTOMANY = util.symbol('MANYTOMANY') + This collection is checked before the 'cascade_iterator' method is called. -class MapperExtension(object): - """Base implementation for customizing ``Mapper`` behavior. - - New extension classes subclass ``MapperExtension`` and are specified - using the ``extension`` mapper() argument, which is a single - ``MapperExtension`` or a list of such. A single mapper - can maintain a chain of ``MapperExtension`` objects. When a - particular mapping event occurs, the corresponding method - on each ``MapperExtension`` is invoked serially, and each method - has the ability to halt the chain from proceeding further. - - Each ``MapperExtension`` method returns the symbol - EXT_CONTINUE by default. This symbol generally means "move - to the next ``MapperExtension`` for processing". 
For methods - that return objects like translated rows or new object - instances, EXT_CONTINUE means the result of the method - should be ignored. In some cases it's required for a - default mapper activity to be performed, such as adding a - new instance to a result list. - - The symbol EXT_STOP has significance within a chain - of ``MapperExtension`` objects that the chain will be stopped - when this symbol is returned. Like EXT_CONTINUE, it also - has additional significance in some cases that a default - mapper activity will not be performed. + The collection typically only applies to a RelationshipProperty. """ - def instrument_class(self, mapper, class_): - """Receive a class when the mapper is first constructed, and has - applied instrumentation to the mapped class. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - return EXT_CONTINUE - def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): - """Receive an instance when it's constructor is called. - - This method is only called during a userland construction of - an object. It is not called when an object is loaded from the - database. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. + is_property = True + """Part of the InspectionAttr interface; states this object is a + mapper property. - """ - return EXT_CONTINUE - - def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): - """Receive an instance when it's constructor has been called, - and raised an exception. - - This method is only called during a userland construction of - an object. It is not called when an object is loaded from the - database. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - return EXT_CONTINUE - - def translate_row(self, mapper, context, row): - """Perform pre-processing on the given result row and return a - new row instance. - - This is called when the mapper first receives a row, before - the object identity or the instance itself has been derived - from that row. The given row may or may not be a - ``RowProxy`` object - it will always be a dictionary-like - object which contains mapped columns as keys. The - returned object should also be a dictionary-like object - which recognizes mapped columns as keys. - - If the ultimate return value is EXT_CONTINUE, the row - is not translated. - - """ - return EXT_CONTINUE - - def create_instance(self, mapper, selectcontext, row, class_): - """Receive a row when a new object instance is about to be - created from that row. - - The method can choose to create the instance itself, or it can return - EXT_CONTINUE to indicate normal object creation should take place. - - mapper - The mapper doing the operation - - selectcontext - The QueryContext generated from the Query. - - row - The result row from the database - - class\_ - The class we are mapping. - - return value - A new object instance, or EXT_CONTINUE - - """ - return EXT_CONTINUE - - def append_result(self, mapper, selectcontext, row, instance, result, **flags): - """Receive an object instance before that instance is appended - to a result list. - - If this method returns EXT_CONTINUE, result appending will proceed - normally. 
if this method returns any other value or None, - result appending will not proceed for this instance, giving - this extension an opportunity to do the appending itself, if - desired. - - mapper - The mapper doing the operation. - - selectcontext - The QueryContext generated from the Query. - - row - The result row from the database. - - instance - The object instance to be appended to the result. - - result - List to which results are being appended. - - \**flags - extra information about the row, same as criterion in - ``create_row_processor()`` method of :class:`~sqlalchemy.orm.interfaces.MapperProperty` - """ - - return EXT_CONTINUE - - def populate_instance(self, mapper, selectcontext, row, instance, **flags): - """Receive an instance before that instance has - its attributes populated. - - This usually corresponds to a newly loaded instance but may - also correspond to an already-loaded instance which has - unloaded attributes to be populated. The method may be called - many times for a single instance, as multiple result rows are - used to populate eagerly loaded collections. - - If this method returns EXT_CONTINUE, instance population will - proceed normally. If any other value or None is returned, - instance population will not proceed, giving this extension an - opportunity to populate the instance itself, if desired. - - As of 0.5, most usages of this hook are obsolete. For a - generic "object has been newly created from a row" hook, use - ``reconstruct_instance()``, or the ``@orm.reconstructor`` - decorator. - - """ - return EXT_CONTINUE - - def reconstruct_instance(self, mapper, instance): - """Receive an object instance after it has been created via - ``__new__``, and after initial attribute population has - occurred. - - This typically occurs when the instance is created based on - incoming result rows, and is only called once for that - instance's lifetime. - - Note that during a result-row load, this method is called upon - the first row received for this instance. Note that some - attributes and collections may or may not be loaded or even - initialized, depending on what's present in the result rows. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - return EXT_CONTINUE - - def before_insert(self, mapper, connection, instance): - """Receive an object instance before that instance is inserted - into its table. - - This is a good place to set up primary key values and such - that aren't handled otherwise. - - Column-based attributes can be modified within this method - which will result in the new value being inserted. However - *no* changes to the overall flush plan can be made, and - manipulation of the ``Session`` will not have the desired effect. - To manipulate the ``Session`` within an extension, use - ``SessionExtension``. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - - return EXT_CONTINUE - - def after_insert(self, mapper, connection, instance): - """Receive an object instance after that instance is inserted. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - - return EXT_CONTINUE - - def before_update(self, mapper, connection, instance): - """Receive an object instance before that instance is updated. 
- - Note that this method is called for all instances that are marked as - "dirty", even those which have no net changes to their column-based - attributes. An object is marked as dirty when any of its column-based - attributes have a "set attribute" operation called or when any of its - collections are modified. If, at update time, no column-based attributes - have any net changes, no UPDATE statement will be issued. This means - that an instance being sent to before_update is *not* a guarantee that - an UPDATE statement will be issued (although you can affect the outcome - here). - - To detect if the column-based attributes on the object have net changes, - and will therefore generate an UPDATE statement, use - ``object_session(instance).is_modified(instance, include_collections=False)``. - - Column-based attributes can be modified within this method - which will result in the new value being updated. However - *no* changes to the overall flush plan can be made, and - manipulation of the ``Session`` will not have the desired effect. - To manipulate the ``Session`` within an extension, use - ``SessionExtension``. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - - return EXT_CONTINUE - - def after_update(self, mapper, connection, instance): - """Receive an object instance after that instance is updated. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - - return EXT_CONTINUE - - def before_delete(self, mapper, connection, instance): - """Receive an object instance before that instance is deleted. - - Note that *no* changes to the overall flush plan can be made - here; and manipulation of the ``Session`` will not have the - desired effect. To manipulate the ``Session`` within an - extension, use ``SessionExtension``. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - - return EXT_CONTINUE - - def after_delete(self, mapper, connection, instance): - """Receive an object instance after that instance is deleted. - - The return value is only significant within the ``MapperExtension`` - chain; the parent mapper's behavior isn't modified by this method. - - """ - - return EXT_CONTINUE - -class SessionExtension(object): - """An extension hook object for Sessions. Subclasses may be installed into a Session - (or sessionmaker) using the ``extension`` keyword argument. """ - def before_commit(self, session): - """Execute right before commit is called. + def _memoized_attr_info(self): + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.InspectionAttr`. - Note that this may not be per-flush if a longer running transaction is ongoing.""" + The dictionary is generated when first accessed. Alternatively, + it can be specified as a constructor argument to the + :func:`.column_property`, :func:`.relationship`, or :func:`.composite` + functions. - def after_commit(self, session): - """Execute after a commit has occured. + .. versionadded:: 0.8 Added support for .info to all + :class:`.MapperProperty` subclasses. - Note that this may not be per-flush if a longer running transaction is ongoing.""" + .. 
versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also + available on extension types via the + :attr:`.InspectionAttrInfo.info` attribute, so that it can apply + to a wider variety of ORM and extension constructs. - def after_rollback(self, session): - """Execute after a rollback has occured. + .. seealso:: - Note that this may not be per-flush if a longer running transaction is ongoing.""" + :attr:`.QueryableAttribute.info` - def before_flush(self, session, flush_context, instances): - """Execute before flush process has started. + :attr:`.SchemaItem.info` - `instances` is an optional list of objects which were passed to the ``flush()`` - method. """ - - def after_flush(self, session, flush_context): - """Execute after flush has completed, but before commit has been called. - - Note that the session's state is still in pre-flush, i.e. 'new', 'dirty', - and 'deleted' lists still show pre-flush state as well as the history - settings on instance attributes.""" - - def after_flush_postexec(self, session, flush_context): - """Execute after flush has completed, and after the post-exec state occurs. - - This will be when the 'new', 'dirty', and 'deleted' lists are in their final - state. An actual commit() may or may not have occured, depending on whether or not - the flush started its own transaction or participated in a larger transaction. - """ - - def after_begin(self, session, transaction, connection): - """Execute after a transaction is begun on a connection - - `transaction` is the SessionTransaction. This method is called after an - engine level transaction is begun on a connection. - """ - - def after_attach(self, session, instance): - """Execute after an instance is attached to a session. - - This is called after an add, delete or merge. - """ - - def after_bulk_update(self, session, query, query_context, result): - """Execute after a bulk update operation to the session. - - This is called after a session.query(...).update() - - `query` is the query object that this update operation was called on. - `query_context` was the query context object. - `result` is the result object returned from the bulk operation. - """ - - def after_bulk_delete(self, session, query, query_context, result): - """Execute after a bulk delete operation to the session. - - This is called after a session.query(...).delete() - - `query` is the query object that this delete operation was called on. - `query_context` was the query context object. - `result` is the result object returned from the bulk operation. - """ - -class MapperProperty(object): - """Manage the relationship of a ``Mapper`` to a single class - attribute, as well as that attribute as it appears on individual - instances of the class, including attribute instrumentation, - attribute access, loading behavior, and dependency calculations. - """ + return {} def setup(self, context, entity, path, adapter, **kwargs): """Called by Query for the purposes of constructing a SQL statement. @@ -415,62 +115,64 @@ class MapperProperty(object): Each MapperProperty associated with the target mapper processes the statement referenced by the query context, adding columns and/or criterion as appropriate. + """ - pass + def create_row_processor(self, context, path, + mapper, result, adapter, populators): + """Produce row processing functions and append to the given + set of populators lists. 
- def create_row_processor(self, selectcontext, path, mapper, row, adapter): - """Return a 2-tuple consiting of two row processing functions and - an instance post-processing function. - - Input arguments are the query.SelectionContext and the *first* - applicable row of a result set obtained within - query.Query.instances(), called only the first time a particular - mapper's populate_instance() method is invoked for the overall result. - - The settings contained within the SelectionContext as well as the - columns present in the row (which will be the same columns present in - all rows) are used to determine the presence and behavior of the - returned callables. The callables will then be used to process all - rows and instances. - - Callables are of the following form:: - - def new_execute(state, dict_, row, isnew): - # process incoming instance state and given row. the instance is - # "new" and was just created upon receipt of this row. - "isnew" indicates if the instance was newly created as a - result of reading this row - - def existing_execute(state, dict_, row): - # process incoming instance state and given row. the instance is - # "existing" and was created based on a previous row. - - return (new_execute, existing_execute) - - Either of the three tuples can be ``None`` in which case no function - is called. """ - raise NotImplementedError() - - def cascade_iterator(self, type_, state, visited_instances=None, halt_on=None): + def cascade_iterator(self, type_, state, visited_instances=None, + halt_on=None): """Iterate through instances related to the given instance for a particular 'cascade', starting with this MapperProperty. - See PropertyLoader for the related instance implementation. + Return an iterator of 3-tuples (instance, mapper, state). + + Note that the 'cascade' collection on this MapperProperty is + checked first for the given type before cascade_iterator is called. + + This method typically only applies to RelationshipProperty. + """ return iter(()) - def set_parent(self, parent): + def set_parent(self, parent, init): + """Set the parent mapper that references this MapperProperty. + + This method is overridden by some subclasses to perform extra + setup when the mapper is first known. + + """ self.parent = parent def instrument_class(self, mapper): - raise NotImplementedError() + """Hook called by the Mapper to the property to initiate + instrumentation of the class attribute managed by this + MapperProperty. - _compile_started = False - _compile_finished = False + The MapperProperty here will typically call out to the + attributes module to set up an InstrumentedAttribute. + + This step is the first of two steps to set up an InstrumentedAttribute, + and is called early in the mapper setup process. + + The second step is typically the init_class_attribute step, + called from StrategizedProperty via the post_instrument_class() + hook. This step assigns additional state to the InstrumentedAttribute + (specifically the "impl") which has been determined after the + MapperProperty has determined what kind of persistence + management it needs to do (e.g. scalar, object, collection, etc). + + """ + + def __init__(self): + self._configure_started = False + self._configure_finished = False def init(self): """Called after all mappers are created to assemble @@ -478,103 +180,200 @@ class MapperProperty(object): initialization steps. 
""" - self._compile_started = True + self._configure_started = True self.do_init() - self._compile_finished = True + self._configure_finished = True @property def class_attribute(self): - """Return the class-bound descriptor corresponding to this MapperProperty.""" - - return getattr(self.parent.class_, self.key) - - def do_init(self): - """Perform subclass-specific initialization post-mapper-creation steps. + """Return the class-bound descriptor corresponding to this + :class:`.MapperProperty`. + + This is basically a ``getattr()`` call:: + + return getattr(self.parent.class_, self.key) + + I.e. if this :class:`.MapperProperty` were named ``addresses``, + and the class to which it is mapped is ``User``, this sequence + is possible:: + + >>> from sqlalchemy import inspect + >>> mapper = inspect(User) + >>> addresses_property = mapper.attrs.addresses + >>> addresses_property.class_attribute is User.addresses + True + >>> User.addresses.property is addresses_property + True - This is a *template* method called by the - ``MapperProperty`` object's init() method. """ - pass + + return getattr(self.parent.class_, self.key) + + def do_init(self): + """Perform subclass-specific initialization post-mapper-creation + steps. + + This is a template method called by the ``MapperProperty`` + object's init() method. + + """ def post_instrument_class(self, mapper): """Perform instrumentation adjustments that need to occur after init() has completed. - """ - pass + The given Mapper is the Mapper invoking the operation, which + may not be the same Mapper as self.parent in an inheritance + scenario; however, Mapper will always at least be a sub-mapper of + self.parent. + + This method is typically used by StrategizedProperty, which delegates + it to LoaderStrategy.init_class_attribute() to perform final setup + on the class-bound InstrumentedAttribute. - def register_dependencies(self, *args, **kwargs): - """Called by the ``Mapper`` in response to the UnitOfWork - calling the ``Mapper``'s register_dependencies operation. - Establishes a topological dependency between two mappers - which will affect the order in which mappers persist data. - """ - pass - - def register_processors(self, *args, **kwargs): - """Called by the ``Mapper`` in response to the UnitOfWork - calling the ``Mapper``'s register_processors operation. - Establishes a processor object between two mappers which - will link data and state between parent/child objects. - - """ - - pass - - def is_primary(self): - """Return True if this ``MapperProperty``'s mapper is the - primary mapper for its class. - - This flag is used to indicate that the ``MapperProperty`` can - define attribute instrumentation for the class at the class - level (as opposed to the individual instance level). - """ - - return not self.parent.non_primary - - def merge(self, session, source, dest, load, _recursive): + def merge(self, session, source_state, source_dict, dest_state, + dest_dict, load, _recursive, _resolve_conflict_map): """Merge the attribute represented by this ``MapperProperty`` - from source to destination object""" + from source to destination object. - raise NotImplementedError() - - def compare(self, operator, value): - """Return a compare operation for the columns represented by - this ``MapperProperty`` to the given value, which may be a - column value or an instance. 'operator' is an operator from - the operators module, or from sql.Comparator. 
- - By default uses the PropComparator attached to this MapperProperty - under the attribute name "comparator". """ - return operator(self.comparator, value) + def __repr__(self): + return '<%s at 0x%x; %s>' % ( + self.__class__.__name__, + id(self), getattr(self, 'key', 'no key')) -class PropComparator(expression.ColumnOperators): - """defines comparison operations for MapperProperty objects. - PropComparator instances should also define an accessor 'property' - which returns the MapperProperty associated with this - PropComparator. +class PropComparator(operators.ColumnOperators): + r"""Defines SQL operators for :class:`.MapperProperty` objects. + + SQLAlchemy allows for operators to + be redefined at both the Core and ORM level. :class:`.PropComparator` + is the base class of operator redefinition for ORM-level operations, + including those of :class:`.ColumnProperty`, + :class:`.RelationshipProperty`, and :class:`.CompositeProperty`. + + .. note:: With the advent of Hybrid properties introduced in SQLAlchemy + 0.7, as well as Core-level operator redefinition in + SQLAlchemy 0.8, the use case for user-defined :class:`.PropComparator` + instances is extremely rare. See :ref:`hybrids_toplevel` as well + as :ref:`types_operators`. + + User-defined subclasses of :class:`.PropComparator` may be created. The + built-in Python comparison and math operator methods, such as + :meth:`.operators.ColumnOperators.__eq__`, + :meth:`.operators.ColumnOperators.__lt__`, and + :meth:`.operators.ColumnOperators.__add__`, can be overridden to provide + new operator behavior. The custom :class:`.PropComparator` is passed to + the :class:`.MapperProperty` instance via the ``comparator_factory`` + argument. In each case, + the appropriate subclass of :class:`.PropComparator` should be used:: + + # definition of custom PropComparator subclasses + + from sqlalchemy.orm.properties import \ + ColumnProperty,\ + CompositeProperty,\ + RelationshipProperty + + class MyColumnComparator(ColumnProperty.Comparator): + def __eq__(self, other): + return self.__clause_element__() == other + + class MyRelationshipComparator(RelationshipProperty.Comparator): + def any(self, expression): + "define the 'any' operation" + # ... + + class MyCompositeComparator(CompositeProperty.Comparator): + def __gt__(self, other): + "redefine the 'greater than' operation" + + return sql.and_(*[a>b for a, b in + zip(self.__clause_element__().clauses, + other.__composite_values__())]) + + + # application of custom PropComparator subclasses + + from sqlalchemy.orm import column_property, relationship, composite + from sqlalchemy import Column, String + + class SomeMappedClass(Base): + some_column = column_property(Column("some_column", String), + comparator_factory=MyColumnComparator) + + some_relationship = relationship(SomeOtherClass, + comparator_factory=MyRelationshipComparator) + + some_composite = composite( + Column("a", String), Column("b", String), + comparator_factory=MyCompositeComparator + ) + + Note that for column-level operator redefinition, it's usually + simpler to define the operators at the Core level, using the + :attr:`.TypeEngine.comparator_factory` attribute. See + :ref:`types_operators` for more detail. 
+ + See also: + + :class:`.ColumnProperty.Comparator` + + :class:`.RelationshipProperty.Comparator` + + :class:`.CompositeProperty.Comparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + """ - def __init__(self, prop, mapper, adapter=None): + __slots__ = 'prop', 'property', '_parententity', '_adapt_to_entity' + + def __init__(self, prop, parentmapper, adapt_to_entity=None): self.prop = self.property = prop - self.mapper = mapper - self.adapter = adapter + self._parententity = adapt_to_entity or parentmapper + self._adapt_to_entity = adapt_to_entity def __clause_element__(self): raise NotImplementedError("%r" % self) - def adapted(self, adapter): - """Return a copy of this PropComparator which will use the given adaption function - on the local side of generated expressions. + def _query_clause_element(self): + return self.__clause_element__() + + def adapt_to_entity(self, adapt_to_entity): + """Return a copy of this PropComparator which will use the given + :class:`.AliasedInsp` to produce corresponding expressions. + """ + return self.__class__(self.prop, self._parententity, adapt_to_entity) + + @property + def _parentmapper(self): + """legacy; this is renamed to _parententity to be + compatible with QueryableAttribute.""" + return inspect(self._parententity).mapper + + @property + def adapter(self): + """Produce a callable that adapts column expressions + to suit an aliased version of this comparator. """ - return self.__class__(self.prop, self.mapper, adapter) + if self._adapt_to_entity is None: + return None + else: + return self._adapt_to_entity._adapt_element + + @property + def info(self): + return self.property.info @staticmethod def any_op(a, b, **kwargs): @@ -589,18 +388,18 @@ class PropComparator(expression.ColumnOperators): return a.of_type(class_) def of_type(self, class_): - """Redefine this object in terms of a polymorphic subclass. + r"""Redefine this object in terms of a polymorphic subclass. - Returns a new PropComparator from which further criterion can be evaluated. + Returns a new PropComparator from which further criterion can be + evaluated. e.g.:: - query.join(Company.employees.of_type(Engineer)).\\ + query.join(Company.employees.of_type(Engineer)).\ filter(Engineer.name=='foo') - \class_ - a class or mapper indicating that criterion will be against - this specific subclass. + :param \class_: a class or mapper indicating that criterion will be + against this specific subclass. """ @@ -608,29 +407,37 @@ class PropComparator(expression.ColumnOperators): return self.operate(PropComparator.of_type_op, class_) def any(self, criterion=None, **kwargs): - """Return true if this collection contains any member that meets the given criterion. + r"""Return true if this collection contains any member that meets the + given criterion. - criterion - an optional ClauseElement formulated against the member class' table - or attributes. + The usual implementation of ``any()`` is + :meth:`.RelationshipProperty.Comparator.any`. + + :param criterion: an optional ClauseElement formulated against the + member class' table or attributes. + + :param \**kwargs: key/value pairs corresponding to member class + attribute names which will be compared via equality to the + corresponding values. - \**kwargs - key/value pairs corresponding to member class attribute names which - will be compared via equality to the corresponding values. 
""" return self.operate(PropComparator.any_op, criterion, **kwargs) def has(self, criterion=None, **kwargs): - """Return true if this element references a member which meets the given criterion. + r"""Return true if this element references a member which meets the + given criterion. - criterion - an optional ClauseElement formulated against the member class' table - or attributes. + The usual implementation of ``has()`` is + :meth:`.RelationshipProperty.Comparator.has`. + + :param criterion: an optional ClauseElement formulated against the + member class' table or attributes. + + :param \**kwargs: key/value pairs corresponding to member class + attribute names which will be compared via equality to the + corresponding values. - \**kwargs - key/value pairs corresponding to member class attribute names which - will be compared via equality to the corresponding values. """ return self.operate(PropComparator.has_op, criterion, **kwargs) @@ -643,326 +450,146 @@ class StrategizedProperty(MapperProperty): There is a single strategy selected by default. Alternate strategies can be selected at Query time through the usage of ``StrategizedOption`` objects via the Query.options() method. - + + The mechanics of StrategizedProperty are used for every Query + invocation for every mapped attribute participating in that Query, + to determine first how the attribute will be rendered in SQL + and secondly how the attribute will retrieve a value from a result + row and apply it to a mapped object. The routines here are very + performance-critical. + """ - def _get_context_strategy(self, context, path): - cls = context.attributes.get(("loaderstrategy", _reduce_path(path)), None) - if cls: - try: - return self.__all_strategies[cls] - except KeyError: - return self.__init_strategy(cls) - else: - return self.strategy - - def _get_strategy(self, cls): - try: - return self.__all_strategies[cls] - except KeyError: - return self.__init_strategy(cls) - - def __init_strategy(self, cls): - self.__all_strategies[cls] = strategy = cls(self) - strategy.init() - return strategy - - def setup(self, context, entity, path, adapter, **kwargs): - self._get_context_strategy(context, path + (self.key,)).\ - setup_query(context, entity, path, adapter, **kwargs) - - def create_row_processor(self, context, path, mapper, row, adapter): - return self._get_context_strategy(context, path + (self.key,)).\ - create_row_processor(context, path, mapper, row, adapter) - - def do_init(self): - self.__all_strategies = {} - self.strategy = self.__init_strategy(self.strategy_class) - - def post_instrument_class(self, mapper): - if self.is_primary(): - self.strategy.init_class_attribute(mapper) - -def build_path(entity, key, prev=None): - if prev: - return prev + (entity, key) - else: - return (entity, key) - -def serialize_path(path): - if path is None: - return None - - return zip( - [m.class_ for m in [path[i] for i in range(0, len(path), 2)]], - [path[i] for i in range(1, len(path), 2)] + [None] + __slots__ = ( + '_strategies', 'strategy', + '_wildcard_token', '_default_path_loader_key' ) -def deserialize_path(path): - if path is None: - return None + strategy_wildcard_key = None + + def _memoized_attr__wildcard_token(self): + return ("%s:%s" % ( + self.strategy_wildcard_key, path_registry._WILDCARD_TOKEN), ) + + def _memoized_attr__default_path_loader_key(self): + return ( + "loader", + ("%s:%s" % ( + self.strategy_wildcard_key, path_registry._DEFAULT_TOKEN), ) + ) + + def _get_context_loader(self, context, path): + load = None + + # use 
EntityRegistry.__getitem__()->PropRegistry here so + # that the path is stated in terms of our base + search_path = dict.__getitem__(path, self) + + # search among: exact match, "attr.*", "default" strategy + # if any. + for path_key in ( + search_path._loader_key, + search_path._wildcard_path_loader_key, + search_path._default_path_loader_key + ): + if path_key in context.attributes: + load = context.attributes[path_key] + break + + return load + + def _get_strategy(self, key): + try: + return self._strategies[key] + except KeyError: + cls = self._strategy_lookup(*key) + self._strategies[key] = self._strategies[ + cls] = strategy = cls(self, key) + return strategy + + def setup( + self, context, entity, path, adapter, **kwargs): + loader = self._get_context_loader(context, path) + if loader and loader.strategy: + strat = self._get_strategy(loader.strategy) + else: + strat = self.strategy + strat.setup_query(context, entity, path, loader, adapter, **kwargs) + + def create_row_processor( + self, context, path, mapper, + result, adapter, populators): + loader = self._get_context_loader(context, path) + if loader and loader.strategy: + strat = self._get_strategy(loader.strategy) + else: + strat = self.strategy + strat.create_row_processor( + context, path, loader, + mapper, result, adapter, populators) + + def do_init(self): + self._strategies = {} + self.strategy = self._get_strategy(self.strategy_key) + + def post_instrument_class(self, mapper): + if not self.parent.non_primary and \ + not mapper.class_manager._attr_has_impl(self.key): + self.strategy.init_class_attribute(mapper) + + _all_strategies = collections.defaultdict(dict) + + @classmethod + def strategy_for(cls, **kw): + def decorate(dec_cls): + # ensure each subclass of the strategy has its + # own _strategy_keys collection + if '_strategy_keys' not in dec_cls.__dict__: + dec_cls._strategy_keys = [] + key = tuple(sorted(kw.items())) + cls._all_strategies[cls][key] = dec_cls + dec_cls._strategy_keys.append(key) + return dec_cls + return decorate + + @classmethod + def _strategy_lookup(cls, *key): + for prop_cls in cls.__mro__: + if prop_cls in cls._all_strategies: + strategies = cls._all_strategies[prop_cls] + try: + return strategies[key] + except KeyError: + pass + raise Exception("can't locate strategy for %s %s" % (cls, key)) - global class_mapper - if class_mapper is None: - from sqlalchemy.orm import class_mapper - - p = tuple(chain(*[(class_mapper(cls), key) for cls, key in path])) - if p and p[-1] is None: - p = p[0:-1] - return p class MapperOption(object): """Describe a modification to a Query.""" propagate_to_loaders = False - """if True, indicate this option should be carried along - Query object generated by scalar or object lazy loaders. + """if True, indicate this option should be carried along + to "secondary" Query objects produced during lazy loads + or refresh operations. + """ - + def process_query(self, query): - pass + """Apply a modification to the given :class:`.Query`.""" def process_query_conditionally(self, query): - """same as process_query(), except that this option may not apply - to the given query. + """same as process_query(), except that this option may not + apply to the given query. + + This is typically used during a lazy load or scalar refresh + operation to propagate options stated in the original Query to the + new Query being used for the load. It occurs for those options that + specify propagate_to_loaders=True. 
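+ A minimal sketch of this contract (illustrative only; the option name and payload are invented, and the internal ``Query._attributes`` dictionary is assumed as the storage point, as the loader options in this module use)::
+
+     from sqlalchemy.orm.interfaces import MapperOption
+
+     class MyOption(MapperOption):
+         # carried along to lazy-load and refresh queries as well
+         propagate_to_loaders = True
+
+         def __init__(self, payload):
+             self.payload = payload
+
+         def process_query(self, query):
+             # record state on the query for later consumption,
+             # keyed so it cannot collide with built-in options
+             query._attributes[("my_option", MyOption)] = self.payload
+
+ Such an option would be applied as ``session.query(SomeClass).options(MyOption("value"))``.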
+ + """ - Used when secondary loaders resend existing options to a new - Query.""" self.process_query(query) -class ExtensionOption(MapperOption): - """a MapperOption that applies a MapperExtension to a query operation.""" - - def __init__(self, ext): - self.ext = ext - - def process_query(self, query): - entity = query._generate_mapper_zero() - entity.extension = entity.extension.copy() - entity.extension.push(self.ext) - -class PropertyOption(MapperOption): - """A MapperOption that is applied to a property off the mapper or - one of its child mappers, identified by a dot-separated key. - """ - - def __init__(self, key, mapper=None): - self.key = key - self.mapper = mapper - - def process_query(self, query): - self._process(query, True) - - def process_query_conditionally(self, query): - self._process(query, False) - - def _process(self, query, raiseerr): - paths, mappers = self._get_paths(query, raiseerr) - if paths: - self.process_query_property(query, paths, mappers) - - def process_query_property(self, query, paths, mappers): - pass - - def __getstate__(self): - d = self.__dict__.copy() - d['key'] = ret = [] - for token in util.to_list(self.key): - if isinstance(token, PropComparator): - ret.append((token.mapper.class_, token.key)) - else: - ret.append(token) - return d - - def __setstate__(self, state): - ret = [] - for key in state['key']: - if isinstance(key, tuple): - cls, propkey = key - ret.append(getattr(cls, propkey)) - else: - ret.append(key) - state['key'] = tuple(ret) - self.__dict__ = state - - def _find_entity(self, query, mapper, raiseerr): - from sqlalchemy.orm.util import _class_to_mapper, _is_aliased_class - - if _is_aliased_class(mapper): - searchfor = mapper - isa = False - else: - searchfor = _class_to_mapper(mapper) - isa = True - - for ent in query._mapper_entities: - if searchfor is ent.path_entity or ( - isa and - searchfor.common_parent(ent.path_entity)): - return ent - else: - if raiseerr: - raise sa_exc.ArgumentError( - "Can't find entity %s in Query. 
Current list: %r" - % (searchfor, [ - str(m.path_entity) for m in query._entities - ])) - else: - return None - - def _get_paths(self, query, raiseerr): - path = None - entity = None - l = [] - mappers = [] - - # _current_path implies we're in a secondary load - # with an existing path - current_path = list(query._current_path) - - tokens = [] - for key in util.to_list(self.key): - if isinstance(key, basestring): - tokens += key.split('.') - else: - tokens += [key] - - for token in tokens: - if isinstance(token, basestring): - if not entity: - if current_path: - if current_path[1] == token: - current_path = current_path[2:] - continue - - entity = query._entity_zero() - path_element = entity.path_entity - mapper = entity.mapper - mappers.append(mapper) - prop = mapper.get_property( - token, - resolve_synonyms=True, - raiseerr=raiseerr) - key = token - elif isinstance(token, PropComparator): - prop = token.property - if not entity: - if current_path: - if current_path[0:2] == [token.parententity, prop.key]: - current_path = current_path[2:] - continue - - entity = self._find_entity( - query, - token.parententity, - raiseerr) - if not entity: - return [], [] - path_element = entity.path_entity - mapper = entity.mapper - mappers.append(prop.parent) - key = prop.key - else: - raise sa_exc.ArgumentError("mapper option expects string key " - "or list of attributes") - - if prop is None: - return [], [] - - path = build_path(path_element, prop.key, path) - l.append(path) - if getattr(token, '_of_type', None): - path_element = mapper = token._of_type - else: - path_element = mapper = getattr(prop, 'mapper', None) - - if path_element: - path_element = path_element - - - # if current_path tokens remain, then - # we didn't have an exact path match. - if current_path: - return [], [] - - return l, mappers - -class AttributeExtension(object): - """An event handler for individual attribute change events. - - AttributeExtension is assembled within the descriptors associated - with a mapped class. - - """ - - active_history = True - """indicates that the set() method would like to receive the 'old' value, - even if it means firing lazy callables. - """ - - def append(self, state, value, initiator): - """Receive a collection append event. - - The returned value will be used as the actual value to be - appended. - - """ - return value - - def remove(self, state, value, initiator): - """Receive a remove event. - - No return value is defined. - - """ - pass - - def set(self, state, value, oldvalue, initiator): - """Receive a set event. - - The returned value will be used as the actual value to be - set. - - """ - return value - - -class StrategizedOption(PropertyOption): - """A MapperOption that affects which LoaderStrategy will be used - for an operation by a StrategizedProperty. - """ - - is_chained = False - - def process_query_property(self, query, paths, mappers): - # _get_context_strategy may receive the path in terms of - # a base mapper - e.g. 
options(eagerload_all(Company.employees, Engineer.machines)) - # in the polymorphic tests leads to "(Person, 'machines')" in - # the path due to the mechanics of how the eager strategy builds - # up the path - if self.is_chained: - for path in paths: - query._attributes[("loaderstrategy", _reduce_path(path))] = \ - self.get_strategy_class() - else: - query._attributes[("loaderstrategy", _reduce_path(paths[-1]))] = \ - self.get_strategy_class() - - def get_strategy_class(self): - raise NotImplementedError() - -def _reduce_path(path): - """Convert a (mapper, path) path to use base mappers. - - This is used to allow more open ended selection of loader strategies, i.e. - Mapper -> prop1 -> Subclass -> prop2, where Subclass is a sub-mapper - of the mapper referened by Mapper.prop1. - - """ - return tuple([i % 2 != 0 and - path[i] or - getattr(path[i], 'base_mapper', path[i]) - for i in xrange(len(path))]) class LoaderStrategy(object): """Describe the loading behavior of a StrategizedProperty object. @@ -978,121 +605,51 @@ class LoaderStrategy(object): * it processes the ``QueryContext`` at statement construction time, where it can modify the SQL statement that is being produced. - simple column attributes may add their represented column to the - list of selected columns, *eager loading* properties may add - ``LEFT OUTER JOIN`` clauses to the statement. + For example, simple column attributes will add their represented + column to the list of selected columns, a joined eager loader + may establish join clauses to add to the statement. + + * It produces "row processor" functions at result fetching time. + These "row processor" functions populate a particular attribute + on a particular mapped instance. - * it processes the ``SelectionContext`` at row-processing time. This - includes straight population of attributes corresponding to rows, - setting instance-level lazyloader callables on newly - constructed instances, and appending child items to scalar/collection - attributes in response to eagerly-loaded relations. """ - def __init__(self, parent): + __slots__ = 'parent_property', 'is_class_level', 'parent', 'key', \ + 'strategy_key', 'strategy_opts' + + def __init__(self, parent, strategy_key): self.parent_property = parent self.is_class_level = False self.parent = self.parent_property.parent self.key = self.parent_property.key - - def init(self): - raise NotImplementedError("LoaderStrategy") + self.strategy_key = strategy_key + self.strategy_opts = dict(strategy_key) def init_class_attribute(self, mapper): pass - def setup_query(self, context, entity, path, adapter, **kwargs): - pass + def setup_query(self, context, entity, path, loadopt, adapter, **kwargs): + """Establish column and other state for a given QueryContext. - def create_row_processor(self, selectcontext, path, mapper, row, adapter): - """Return row processing functions which fulfill the contract specified - by MapperProperty.create_row_processor. + This method fulfills the contract specified by MapperProperty.setup(). + + StrategizedProperty delegates its setup() method + directly to this method. - StrategizedProperty delegates its create_row_processor method directly - to this method. """ - raise NotImplementedError() + def create_row_processor(self, context, path, loadopt, mapper, + result, adapter, populators): + """Establish row processing functions for a given QueryContext. + + This method fulfills the contract specified by + MapperProperty.create_row_processor(). 
+ + StrategizedProperty delegates its create_row_processor() method + directly to this method. + + """ def __str__(self): return str(self.parent_property) - - def debug_callable(self, fn, logger, announcement, logfn): - if announcement: - logger.debug(announcement) - if logfn: - def call(*args, **kwargs): - logger.debug(logfn(*args, **kwargs)) - return fn(*args, **kwargs) - return call - else: - return fn - -class InstrumentationManager(object): - """User-defined class instrumentation extension. - - The API for this class should be considered as semi-stable, - and may change slightly with new releases. - - """ - - # r4361 added a mandatory (cls) constructor to this interface. - # given that, perhaps class_ should be dropped from all of these - # signatures. - - def __init__(self, class_): - pass - - def manage(self, class_, manager): - setattr(class_, '_default_class_manager', manager) - - def dispose(self, class_, manager): - delattr(class_, '_default_class_manager') - - def manager_getter(self, class_): - def get(cls): - return cls._default_class_manager - return get - - def instrument_attribute(self, class_, key, inst): - pass - - def post_configure_attribute(self, class_, key, inst): - pass - - def install_descriptor(self, class_, key, inst): - setattr(class_, key, inst) - - def uninstall_descriptor(self, class_, key): - delattr(class_, key) - - def install_member(self, class_, key, implementation): - setattr(class_, key, implementation) - - def uninstall_member(self, class_, key): - delattr(class_, key) - - def instrument_collection_class(self, class_, key, collection_class): - global collections - if collections is None: - from sqlalchemy.orm import collections - return collections.prepare_instrumentation(collection_class) - - def get_instance_dict(self, class_, instance): - return instance.__dict__ - - def initialize_instance_dict(self, class_, instance): - pass - - def install_state(self, class_, instance, state): - setattr(instance, '_default_state', state) - - def remove_state(self, class_, instance): - delattr(instance, '_default_state', state) - - def state_getter(self, class_): - return lambda instance: getattr(instance, '_default_state') - - def dict_getter(self, class_): - return lambda inst: self.get_instance_dict(class_, inst) - \ No newline at end of file diff --git a/sqlalchemy/orm/mapper.py b/sqlalchemy/orm/mapper.py index 8f0f212..962486d 100644 --- a/sqlalchemy/orm/mapper.py +++ b/sqlalchemy/orm/mapper.py @@ -1,107 +1,561 @@ -# mapper.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# orm/mapper.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """Logic to map Python classes to and from selectables. -Defines the :class:`~sqlalchemy.orm.mapper.Mapper` class, the central configurational -unit which associates a class with a database table. +Defines the :class:`~sqlalchemy.orm.mapper.Mapper` class, the central +configurational unit which associates a class with a database table. This is a semi-private module; the main configurational API of the ORM is available in :class:`~sqlalchemy.orm.`. 
""" +from __future__ import absolute_import import types import weakref -import operator from itertools import chain -deque = __import__('collections').deque +from collections import deque -from sqlalchemy import sql, util, log, exc as sa_exc -from sqlalchemy.sql import expression, visitors, operators, util as sqlutil -from sqlalchemy.orm import attributes, sync, exc as orm_exc -from sqlalchemy.orm.interfaces import ( - MapperProperty, EXT_CONTINUE, PropComparator - ) -from sqlalchemy.orm.util import ( - ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, _state_has_identity, - _state_mapper, class_mapper, instance_str, state_str, - ) +from .. import sql, util, log, exc as sa_exc, event, schema, inspection +from ..sql import expression, visitors, operators, util as sql_util +from . import instrumentation, attributes, exc as orm_exc, loading +from . import properties +from . import util as orm_util +from .interfaces import MapperProperty, InspectionAttr, _MappedAttribute + +from .base import _class_to_mapper, _state_mapper, class_mapper, \ + state_str, _INSTRUMENTOR +from .path_registry import PathRegistry + +import sys -__all__ = ( - 'Mapper', - '_mapper_registry', - 'class_mapper', - 'object_mapper', - ) _mapper_registry = weakref.WeakKeyDictionary() -_new_mappers = False _already_compiling = False -_none_set = frozenset([None]) -# a list of MapperExtensions that will be installed in all mappers by default -global_extensions = [] +_memoized_configured_property = util.group_expirable_memoized_property() + # a constant returned by _get_attr_by_column to indicate # this mapper is not handling an attribute for a particular # column NO_ATTRIBUTE = util.symbol('NO_ATTRIBUTE') -# lock used to synchronize the "mapper compile" step -_COMPILE_MUTEX = util.threading.RLock() +# lock used to synchronize the "mapper configure" step +_CONFIGURE_MUTEX = util.threading.RLock() -# initialize these lazily -ColumnProperty = None -SynonymProperty = None -ComparableProperty = None -RelationshipProperty = None -ConcreteInheritedProperty = None -_expire_state = None -_state_session = None -class Mapper(object): +@inspection._self_inspects +@log.class_logger +class Mapper(InspectionAttr): """Define the correlation of class attributes to database table columns. - Instances of this class should be constructed via the - :func:`~sqlalchemy.orm.mapper` function. + The :class:`.Mapper` object is instantiated using the + :func:`~sqlalchemy.orm.mapper` function. For information + about instantiating new :class:`.Mapper` objects, see + that function's documentation. + + + When :func:`.mapper` is used + explicitly to link a user defined class with table + metadata, this is referred to as *classical mapping*. + Modern SQLAlchemy usage tends to favor the + :mod:`sqlalchemy.ext.declarative` extension for class + configuration, which + makes usage of :func:`.mapper` behind the scenes. + + Given a particular class known to be mapped by the ORM, + the :class:`.Mapper` which maintains it can be acquired + using the :func:`.inspect` function:: + + from sqlalchemy import inspect + + mapper = inspect(MyClass) + + A class which was mapped by the :mod:`sqlalchemy.ext.declarative` + extension will also have its mapper available via the ``__mapper__`` + attribute. 
+ """ + + _new_mappers = False + def __init__(self, class_, - local_table, - properties = None, - primary_key = None, - non_primary = False, - inherits = None, - inherit_condition = None, - inherit_foreign_keys = None, - extension = None, - order_by = False, - always_refresh = False, - version_id_col = None, - version_id_generator = None, + local_table=None, + properties=None, + primary_key=None, + non_primary=False, + inherits=None, + inherit_condition=None, + inherit_foreign_keys=None, + extension=None, + order_by=False, + always_refresh=False, + version_id_col=None, + version_id_generator=None, polymorphic_on=None, _polymorphic_map=None, polymorphic_identity=None, concrete=False, with_polymorphic=None, - allow_null_pks=None, allow_partial_pks=True, batch=True, column_prefix=None, include_properties=None, exclude_properties=None, passive_updates=True, - eager_defaults=False): - """Construct a new mapper. + passive_deletes=False, + confirm_deleted_rows=True, + eager_defaults=False, + legacy_is_orphan=False, + _compiled_cache_size=100, + ): + r"""Return a new :class:`~.Mapper` object. - Mappers are normally constructed via the :func:`~sqlalchemy.orm.mapper` - function. See for details. + This function is typically used behind the scenes + via the Declarative extension. When using Declarative, + many of the usual :func:`.mapper` arguments are handled + by the Declarative extension itself, including ``class_``, + ``local_table``, ``properties``, and ``inherits``. + Other options are passed to :func:`.mapper` using + the ``__mapper_args__`` class variable:: + + class MyClass(Base): + __tablename__ = 'my_table' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + alt = Column("some_alt", Integer) + + __mapper_args__ = { + 'polymorphic_on' : type + } + + + Explicit use of :func:`.mapper` + is often referred to as *classical mapping*. The above + declarative example is equivalent in classical form to:: + + my_table = Table("my_table", metadata, + Column('id', Integer, primary_key=True), + Column('type', String(50)), + Column("some_alt", Integer) + ) + + class MyClass(object): + pass + + mapper(MyClass, my_table, + polymorphic_on=my_table.c.type, + properties={ + 'alt':my_table.c.some_alt + }) + + .. seealso:: + + :ref:`classical_mapping` - discussion of direct usage of + :func:`.mapper` + + :param class\_: The class to be mapped. When using Declarative, + this argument is automatically passed as the declared class + itself. + + :param local_table: The :class:`.Table` or other selectable + to which the class is mapped. May be ``None`` if + this mapper inherits from another mapper using single-table + inheritance. When using Declarative, this argument is + automatically passed by the extension, based on what + is configured via the ``__table__`` argument or via the + :class:`.Table` produced as a result of the ``__tablename__`` + and :class:`.Column` arguments present. + + :param always_refresh: If True, all query operations for this mapped + class will overwrite all data within object instances that already + exist within the session, erasing any in-memory changes with + whatever information was loaded from the database. Usage of this + flag is highly discouraged; as an alternative, see the method + :meth:`.Query.populate_existing`. + + :param allow_partial_pks: Defaults to True. Indicates that a + composite primary key with some NULL values should be considered as + possibly existing within the database. 
This affects whether a + mapper will assign an incoming row to an existing identity, as well + as if :meth:`.Session.merge` will check the database first for a + particular primary key value. A "partial primary key" can occur if + one has mapped to an OUTER JOIN, for example. + + :param batch: Defaults to ``True``, indicating that save operations + of multiple entities can be batched together for efficiency. + Setting to False indicates + that an instance will be fully saved before saving the next + instance. This is used in the extremely rare case that a + :class:`.MapperEvents` listener requires being called + in between individual row persistence operations. + + :param column_prefix: A string which will be prepended + to the mapped attribute name when :class:`.Column` + objects are automatically assigned as attributes to the + mapped class. Does not affect explicitly specified + column-based properties. + + See the section :ref:`column_prefix` for an example. + + :param concrete: If True, indicates this mapper should use concrete + table inheritance with its parent mapper. + + See the section :ref:`concrete_inheritance` for an example. + + :param confirm_deleted_rows: defaults to True; when a DELETE occurs + of one or more rows based on specific primary keys, a warning is + emitted when the number of rows matched does not equal the number + of rows expected. This parameter may be set to False to handle the + case where database ON DELETE CASCADE rules may be deleting some of + those rows automatically. The warning may be changed to an + exception in a future release. + + .. versionadded:: 0.9.4 - added + :paramref:`.mapper.confirm_deleted_rows` as well as conditional + matched row checking on delete. + + :param eager_defaults: if True, the ORM will immediately fetch the + value of server-generated default values after an INSERT or UPDATE, + rather than leaving them as expired to be fetched on next access. + This can be used for event schemes where the server-generated values + are needed immediately before the flush completes. By default, + this scheme will emit an individual ``SELECT`` statement per row + inserted or updated, which can add significant performance + overhead. However, if the + target database supports :term:`RETURNING`, the default values will + be returned inline with the INSERT or UPDATE statement, which can + greatly enhance performance for an application that needs frequent + access to just-generated server defaults. + + .. versionchanged:: 0.9.0 The ``eager_defaults`` option can now + make use of :term:`RETURNING` for backends which support it. + + :param exclude_properties: A list or set of string column names to + be excluded from mapping. + + See :ref:`include_exclude_cols` for an example. + + :param extension: A :class:`.MapperExtension` instance or + list of :class:`.MapperExtension` instances which will be applied + to all operations by this :class:`.Mapper`. **Deprecated.** + Please see :class:`.MapperEvents`. + + :param include_properties: An inclusive list or set of string column + names to map. + + See :ref:`include_exclude_cols` for an example. + + :param inherits: A mapped class or the corresponding :class:`.Mapper` + of one indicating a superclass from which this :class:`.Mapper` + should *inherit*. The mapped class here must be a subclass + of the other mapper's class. When using Declarative, this argument + is passed automatically as a result of the natural class + hierarchy of the declared classes. + + ..
seealso:: + + :ref:`inheritance_toplevel` + + :param inherit_condition: For joined table inheritance, a SQL + expression which will + define how the two tables are joined; defaults to a natural join + between the two tables. + + :param inherit_foreign_keys: When ``inherit_condition`` is used and + the columns present are missing a :class:`.ForeignKey` + configuration, this parameter can be used to specify which columns + are "foreign". In most cases this can be left as ``None``. + + :param legacy_is_orphan: Boolean, defaults to ``False``. + When ``True``, specifies that "legacy" orphan consideration + is to be applied to objects mapped by this mapper, which means + that a pending (that is, not persistent) object is auto-expunged + from an owning :class:`.Session` only when it is de-associated + from *all* parents that specify a ``delete-orphan`` cascade towards + this mapper. The new default behavior is that the object is + auto-expunged when it is de-associated with *any* of its parents + that specify ``delete-orphan`` cascade. This behavior is more + consistent with that of a persistent object, and allows behavior to + be consistent in more scenarios independently of whether or not an + orphanable object has been flushed yet. + + See the change note and example at :ref:`legacy_is_orphan_addition` + for more detail on this change. + + .. versionadded:: 0.8 - the consideration of a pending object as + an "orphan" has been modified to more closely match the + behavior of persistent objects, which is that the object + is expunged from the :class:`.Session` as soon as it is + de-associated from any of its orphan-enabled parents. Previously, + the pending object would be expunged only if de-associated + from all of its orphan-enabled parents. The new flag + ``legacy_is_orphan`` is added to :func:`.orm.mapper` which + re-establishes the legacy behavior. + + :param non_primary: Specify that this :class:`.Mapper` is in addition + to the "primary" mapper, that is, the one used for persistence. + The :class:`.Mapper` created here may be used for ad-hoc + mapping of the class to an alternate selectable, for loading + only. + + :paramref:`.Mapper.non_primary` is not an often used option, but + is useful in some specific :func:`.relationship` cases. + + .. seealso:: + + :ref:`relationship_non_primary_mapper` + + :param order_by: A single :class:`.Column` or list of :class:`.Column` + objects which selection operations should use as the default + ordering for entities. By default mappers have no pre-defined + ordering. + + .. deprecated:: 1.1 The :paramref:`.Mapper.order_by` parameter + is deprecated. Use :meth:`.Query.order_by` to determine the + ordering of a result set. + + :param passive_deletes: Indicates DELETE behavior of foreign key + columns when a joined-table inheritance entity is being deleted. + Defaults to ``False`` for a base mapper; for an inheriting mapper, + defaults to ``False`` unless the value is set to ``True`` + on the superclass mapper. + + When ``True``, it is assumed that ON DELETE CASCADE is configured + on the foreign key relationships that link this mapper's table + to its superclass table, so that when the unit of work attempts + to delete the entity, it need only emit a DELETE statement for the + superclass table, and not this table. + + When ``False``, a DELETE statement is emitted for this mapper's + table individually.
If the primary key attributes local to this + table are unloaded, then a SELECT must be emitted in order to + validate these attributes; note that the primary key columns + of a joined-table subclass are not part of the "primary key" of + the object as a whole. + + Note that a value of ``True`` is **always** forced onto the + subclass mappers; that is, it's not possible for a superclass + to specify passive_deletes without this taking effect for + all subclass mappers. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`passive_deletes` - description of similar feature as + used with :func:`.relationship` + + :paramref:`.mapper.passive_updates` - supporting ON UPDATE + CASCADE for joined-table inheritance mappers + + :param passive_updates: Indicates UPDATE behavior of foreign key + columns when a primary key column changes on a joined-table + inheritance mapping. Defaults to ``True``. + + When True, it is assumed that ON UPDATE CASCADE is configured on + the foreign key in the database, and that the database will handle + propagation of an UPDATE from a source column to dependent columns + on joined-table rows. + + When False, it is assumed that the database does not enforce + referential integrity and will not be issuing its own CASCADE + operation for an update. The unit of work process will + emit an UPDATE statement for the dependent columns during a + primary key change. + + .. seealso:: + + :ref:`passive_updates` - description of a similar feature as + used with :func:`.relationship` + + :paramref:`.mapper.passive_deletes` - supporting ON DELETE + CASCADE for joined-table inheritance mappers + + :param polymorphic_on: Specifies the column, attribute, or + SQL expression used to determine the target class for an + incoming row, when inheriting classes are present. + + This value is commonly a :class:`.Column` object that's + present in the mapped :class:`.Table`:: + + class Employee(Base): + __tablename__ = 'employee' + + id = Column(Integer, primary_key=True) + discriminator = Column(String(50)) + + __mapper_args__ = { + "polymorphic_on":discriminator, + "polymorphic_identity":"employee" + } + + It may also be specified + as a SQL expression, as in this example where we + use the :func:`.case` construct to provide a conditional + approach:: + + class Employee(Base): + __tablename__ = 'employee' + + id = Column(Integer, primary_key=True) + discriminator = Column(String(50)) + + __mapper_args__ = { + "polymorphic_on":case([ + (discriminator == "EN", "engineer"), + (discriminator == "MA", "manager"), + ], else_="employee"), + "polymorphic_identity":"employee" + } + + It may also refer to any attribute + configured with :func:`.column_property`, or to the + string name of one:: + + class Employee(Base): + __tablename__ = 'employee' + + id = Column(Integer, primary_key=True) + discriminator = Column(String(50)) + employee_type = column_property( + case([ + (discriminator == "EN", "engineer"), + (discriminator == "MA", "manager"), + ], else_="employee") + ) + + __mapper_args__ = { + "polymorphic_on":employee_type, + "polymorphic_identity":"employee" + } + + .. versionchanged:: 0.7.4 + ``polymorphic_on`` may be specified as a SQL expression, + or refer to any attribute configured with + :func:`.column_property`, or to the string name of one. 
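+ The string-name form mentioned in the note above can be sketched by adapting the previous example (illustrative only; it reuses the ``Base``, ``Column``, ``String``, ``case`` and ``column_property`` names from the examples earlier in this docstring)::
+
+     class Employee(Base):
+         __tablename__ = 'employee'
+
+         id = Column(Integer, primary_key=True)
+         discriminator = Column(String(50))
+         employee_type = column_property(
+             case([
+                 (discriminator == "EN", "engineer"),
+                 (discriminator == "MA", "manager"),
+             ], else_="employee")
+         )
+
+         __mapper_args__ = {
+             # the string name of the column_property above,
+             # resolved when the mapper is configured
+             "polymorphic_on": "employee_type",
+             "polymorphic_identity": "employee"
+         }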
+ + When setting ``polymorphic_on`` to reference an + attribute or expression that's not present in the + locally mapped :class:`.Table`, yet the value + of the discriminator should be persisted to the database, + the value of the + discriminator is not automatically set on new + instances; this must be handled by the user, + either through manual means or via event listeners. + A typical approach to establishing such a listener + looks like:: + + from sqlalchemy import event + from sqlalchemy.orm import object_mapper + + @event.listens_for(Employee, "init", propagate=True) + def set_identity(instance, *arg, **kw): + mapper = object_mapper(instance) + instance.discriminator = mapper.polymorphic_identity + + Where above, we assign the value of ``polymorphic_identity`` + for the mapped class to the ``discriminator`` attribute, + thus persisting the value to the ``discriminator`` column + in the database. + + .. warning:: + + Currently, **only one discriminator column may be set**, typically + on the base-most class in the hierarchy. "Cascading" polymorphic + columns are not yet supported. + + .. seealso:: + + :ref:`inheritance_toplevel` + + :param polymorphic_identity: Specifies the value which + identifies this particular class as returned by the + column expression referred to by the ``polymorphic_on`` + setting. As rows are received, the value corresponding + to the ``polymorphic_on`` column expression is compared + to this value, indicating which subclass should + be used for the newly reconstructed object. + + :param properties: A dictionary mapping the string names of object + attributes to :class:`.MapperProperty` instances, which define the + persistence behavior of that attribute. Note that :class:`.Column` + objects present in + the mapped :class:`.Table` are automatically placed into + ``ColumnProperty`` instances upon mapping, unless overridden. + When using Declarative, this argument is passed automatically, + based on all those :class:`.MapperProperty` instances declared + in the declared class body. + + :param primary_key: A list of :class:`.Column` objects which define + the primary key to be used against this mapper's selectable unit. + This is normally simply the primary key of the ``local_table``, but + can be overridden here. + + :param version_id_col: A :class:`.Column` + that will be used to keep a running version id of rows + in the table. This is used to detect concurrent updates or + the presence of stale data in a flush. If an UPDATE statement + does not match the last known version id, a + :class:`~sqlalchemy.orm.exc.StaleDataError` exception is + raised. + By default, the column must be of :class:`.Integer` type, + unless ``version_id_generator`` specifies an alternative version + generator. + + .. seealso:: + + :ref:`mapper_version_counter` - discussion of version counting + and rationale. + + :param version_id_generator: Define how new version ids should + be generated. Defaults to ``None``, which indicates that + a simple integer counting scheme be employed. To provide a custom + versioning scheme, provide a callable function of the form:: + + def generate_version(version): + return next_version + + Alternatively, server-side versioning functions such as triggers, + or programmatic versioning schemes outside of the version id + generator may be used, by specifying the value ``False``. + Please see :ref:`server_side_version_counter` for a discussion + of important points when using this option. + + ..
versionadded:: 0.9.0 ``version_id_generator`` supports + server-side version number generation. + + .. seealso:: + + :ref:`custom_version_counter` + + :ref:`server_side_version_counter` + + + :param with_polymorphic: A tuple in the form ``(<classes>, + <selectable>)`` indicating the default style of "polymorphic" + loading, that is, which tables are queried at once. <classes> is + any single or list of mappers and/or classes indicating the + inherited classes that should be loaded at once. The special value + ``'*'`` may be used to indicate all descending classes should be + loaded immediately. The second tuple argument <selectable> + indicates a selectable that will be used to query for multiple + classes. + + .. seealso:: + + :ref:`with_polymorphic` - discussion of polymorphic querying + techniques. """ @@ -109,124 +563,405 @@ class Mapper(object): self.class_manager = None - self.primary_key_argument = primary_key + self._primary_key_argument = util.to_list(primary_key) self.non_primary = non_primary if order_by is not False: self.order_by = util.to_list(order_by) + util.warn_deprecated( + "Mapper.order_by is deprecated. " + "Use Query.order_by() in order to affect the ordering of ORM " + "result sets.") + else: self.order_by = order_by - + self.always_refresh = always_refresh - self.version_id_col = version_id_col - self.version_id_generator = version_id_generator or (lambda x:(x or 0) + 1) + + if isinstance(version_id_col, MapperProperty): + self.version_id_prop = version_id_col + self.version_id_col = None + else: + self.version_id_col = version_id_col + if version_id_generator is False: + self.version_id_generator = False + elif version_id_generator is None: + self.version_id_generator = lambda x: (x or 0) + 1 + else: + self.version_id_generator = version_id_generator + self.concrete = concrete self.single = False self.inherits = inherits self.local_table = local_table self.inherit_condition = inherit_condition self.inherit_foreign_keys = inherit_foreign_keys - self.extension = extension self._init_properties = properties or {} - self.delete_orphans = [] self.batch = batch self.eager_defaults = eager_defaults self.column_prefix = column_prefix - self.polymorphic_on = polymorphic_on + self.polymorphic_on = expression._clause_element_as_expr( + polymorphic_on) self._dependency_processors = [] - self._validators = {} + self.validators = util.immutabledict() self.passive_updates = passive_updates + self.passive_deletes = passive_deletes + self.legacy_is_orphan = legacy_is_orphan self._clause_adapter = None self._requires_row_aliasing = False self._inherits_equated_pairs = None - - if allow_null_pks: - util.warn_deprecated('the allow_null_pks option to Mapper() is ' - 'deprecated. 
It is now allow_partial_pks=False|True, ' - 'defaults to True.') - allow_partial_pks = allow_null_pks - + self._memoized_values = {} + self._compiled_cache_size = _compiled_cache_size + self._reconstructor = None + self._deprecated_extensions = util.to_list(extension or []) self.allow_partial_pks = allow_partial_pks - - if with_polymorphic == '*': - self.with_polymorphic = ('*', None) - elif isinstance(with_polymorphic, (tuple, list)): - if isinstance(with_polymorphic[0], (basestring, tuple, list)): - self.with_polymorphic = with_polymorphic - else: - self.with_polymorphic = (with_polymorphic, None) - elif with_polymorphic is not None: - raise sa_exc.ArgumentError("Invalid setting for with_polymorphic") - else: - self.with_polymorphic = None - if isinstance(self.local_table, expression._SelectBaseMixin): + if self.inherits and not self.concrete: + self.confirm_deleted_rows = False + else: + self.confirm_deleted_rows = confirm_deleted_rows + + self._set_with_polymorphic(with_polymorphic) + + if isinstance(self.local_table, expression.SelectBase): raise sa_exc.InvalidRequestError( "When mapping against a select() construct, map against " "an alias() of the construct instead." "This because several databases don't allow a " "SELECT from a subquery that does not have an alias." - ) + ) if self.with_polymorphic and \ - isinstance(self.with_polymorphic[1], expression._SelectBaseMixin): - self.with_polymorphic = (self.with_polymorphic[0], self.with_polymorphic[1].alias()) + isinstance(self.with_polymorphic[1], + expression.SelectBase): + self.with_polymorphic = (self.with_polymorphic[0], + self.with_polymorphic[1].alias()) - # our 'polymorphic identity', a string name that when located in a result set row - # indicates this Mapper should be used to construct the object instance for that row. + # our 'polymorphic identity', a string name that when located in a + # result set row indicates this Mapper should be used to construct + # the object instance for that row. self.polymorphic_identity = polymorphic_identity - # a dictionary of 'polymorphic identity' names, associating those names with - # Mappers that will be used to construct object instances upon a select operation. + # a dictionary of 'polymorphic identity' names, associating those + # names with Mappers that will be used to construct object instances + # upon a select operation. 
if _polymorphic_map is None: self.polymorphic_map = {} else: self.polymorphic_map = _polymorphic_map - self.include_properties = include_properties - self.exclude_properties = exclude_properties + if include_properties is not None: + self.include_properties = util.to_set(include_properties) + else: + self.include_properties = None + if exclude_properties: + self.exclude_properties = util.to_set(exclude_properties) + else: + self.exclude_properties = None + + self.configured = False - self.compiled = False - # prevent this mapper from being constructed - # while a compile() is occuring (and defer a compile() - # until construction succeeds) - _COMPILE_MUTEX.acquire() + # while a configure_mappers() is occurring (and defer a + # configure_mappers() until construction succeeds) + _CONFIGURE_MUTEX.acquire() try: + self.dispatch._events._new_mapper_instance(class_, self) self._configure_inheritance() - self._configure_extensions() + self._configure_legacy_instrument_class() self._configure_class_instrumentation() + self._configure_listeners() self._configure_properties() + self._configure_polymorphic_setter() self._configure_pks() - global _new_mappers - _new_mappers = True + Mapper._new_mappers = True self._log("constructed") + self._expire_memoizations() finally: - _COMPILE_MUTEX.release() - + _CONFIGURE_MUTEX.release() + + # major attributes initialized at the classlevel so that + # they can be Sphinx-documented. + + is_mapper = True + """Part of the inspection API.""" + + @property + def mapper(self): + """Part of the inspection API. + + Returns self. + + """ + return self + + @property + def entity(self): + r"""Part of the inspection API. + + Returns self.class\_. + + """ + return self.class_ + + local_table = None + """The :class:`.Selectable` which this :class:`.Mapper` manages. + + Typically is an instance of :class:`.Table` or :class:`.Alias`. + May also be ``None``. + + The "local" table is the + selectable that the :class:`.Mapper` is directly responsible for + managing from an attribute access and flush perspective. For + non-inheriting mappers, the local table is the same as the + "mapped" table. For joined-table inheritance mappers, local_table + will be the particular sub-table of the overall "join" which + this :class:`.Mapper` represents. If this mapper is a + single-table inheriting mapper, local_table will be ``None``. + + .. seealso:: + + :attr:`~.Mapper.mapped_table`. + + """ + + mapped_table = None + """The :class:`.Selectable` to which this :class:`.Mapper` is mapped. + + Typically an instance of :class:`.Table`, :class:`.Join`, or + :class:`.Alias`. + + The "mapped" table is the selectable that + the mapper selects from during queries. For non-inheriting + mappers, the mapped table is the same as the "local" table. + For joined-table inheritance mappers, mapped_table references the + full :class:`.Join` representing full rows for this particular + subclass. For single-table inheritance mappers, mapped_table + references the base table. + + .. seealso:: + + :attr:`~.Mapper.local_table`. + + """ + + inherits = None + """References the :class:`.Mapper` which this :class:`.Mapper` + inherits from, if any. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + configured = None + """Represent ``True`` if this :class:`.Mapper` has been configured. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + .. 
seealso:: + + :func:`.configure_mappers`. + + """ + + concrete = None + """Represent ``True`` if this :class:`.Mapper` is a concrete + inheritance mapper. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + tables = None + """An iterable containing the collection of :class:`.Table` objects + which this :class:`.Mapper` is aware of. + + If the mapper is mapped to a :class:`.Join`, or an :class:`.Alias` + representing a :class:`.Select`, the individual :class:`.Table` + objects that comprise the full construct will be represented here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + primary_key = None + """An iterable containing the collection of :class:`.Column` objects + which comprise the 'primary key' of the mapped table, from the + perspective of this :class:`.Mapper`. + + This list is against the selectable in :attr:`~.Mapper.mapped_table`. In + the case of inheriting mappers, some columns may be managed by a + superclass mapper. For example, in the case of a :class:`.Join`, the + primary key is determined by all of the primary key columns across all + tables referenced by the :class:`.Join`. + + The list is also not necessarily the same as the primary key column + collection associated with the underlying tables; the :class:`.Mapper` + features a ``primary_key`` argument that can override what the + :class:`.Mapper` considers as primary key columns. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + class_ = None + """The Python class which this :class:`.Mapper` maps. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + class_manager = None + """The :class:`.ClassManager` which maintains event listeners + and class-bound descriptors for this :class:`.Mapper`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + single = None + """Represent ``True`` if this :class:`.Mapper` is a single table + inheritance mapper. + + :attr:`~.Mapper.local_table` will be ``None`` if this flag is set. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + non_primary = None + """Represent ``True`` if this :class:`.Mapper` is a "non-primary" + mapper, e.g. a mapper that is used only to select rows but not for + persistence management. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_on = None + """The :class:`.Column` or SQL expression specified as the + ``polymorphic_on`` argument + for this :class:`.Mapper`, within an inheritance scenario. + + This attribute is normally a :class:`.Column` instance but + may also be an expression, such as one derived from + :func:`.cast`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_map = None + """A mapping of "polymorphic identity" identifiers mapped to + :class:`.Mapper` instances, within an inheritance scenario. + + The identifiers can be of any type which is comparable to the + type of column represented by :attr:`~.Mapper.polymorphic_on`. 
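+ A small sketch of how this map can be observed (illustrative only; it assumes an ``Employee``/``Engineer`` hierarchy mapped with polymorphic identities ``"employee"`` and ``"engineer"``)::
+
+     from sqlalchemy import inspect
+
+     emp_mapper = inspect(Employee)
+
+     # identity values key the map; the values are Mapper objects
+     assert emp_mapper.polymorphic_map["engineer"] is inspect(Engineer)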
+ + An inheritance chain of mappers will all reference the same + polymorphic map object. The object is used to correlate incoming + result rows to target mappers. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_identity = None + """Represent an identifier which is matched against the + :attr:`~.Mapper.polymorphic_on` column during result row loading. + + Used only with inheritance, this object can be of any type which is + comparable to the type of column represented by + :attr:`~.Mapper.polymorphic_on`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + base_mapper = None + """The base-most :class:`.Mapper` in an inheritance chain. + + In a non-inheriting scenario, this attribute will always be this + :class:`.Mapper`. In an inheritance scenario, it references + the :class:`.Mapper` which is parent to all other :class:`.Mapper` + objects in the inheritance chain. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + columns = None + """A collection of :class:`.Column` or other scalar expression + objects maintained by this :class:`.Mapper`. + + The collection behaves the same as that of the ``c`` attribute on + any :class:`.Table` object, except that only those columns included in + this mapping are present, and are keyed based on the attribute name + defined in the mapping, not necessarily the ``key`` attribute of the + :class:`.Column` itself. Additionally, scalar expressions mapped + by :func:`.column_property` are also present here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + validators = None + """An immutable dictionary of attributes which have been decorated + using the :func:`~.orm.validates` decorator. + + The dictionary contains string attribute names as keys + mapped to the actual validation method. + + """ + + c = None + """A synonym for :attr:`~.Mapper.columns`.""" + + @util.memoized_property + def _path_registry(self): + return PathRegistry.per_mapper(self) + def _configure_inheritance(self): - """Configure settings related to inherting and/or inherited mappers being present.""" + """Configure settings related to inheriting and/or inherited mappers + being present.""" # a set of all mappers which inherit from this one. - self._inheriting_mappers = set() + self._inheriting_mappers = util.WeakSequence() if self.inherits: if isinstance(self.inherits, type): - self.inherits = class_mapper(self.inherits, compile=False) + self.inherits = class_mapper(self.inherits, configure=False) if not issubclass(self.class_, self.inherits.class_): raise sa_exc.ArgumentError( - "Class '%s' does not inherit from '%s'" % - (self.class_.__name__, self.inherits.class_.__name__)) + "Class '%s' does not inherit from '%s'" % + (self.class_.__name__, self.inherits.class_.__name__)) if self.non_primary != self.inherits.non_primary: np = not self.non_primary and "primary" or "non-primary" - raise sa_exc.ArgumentError("Inheritance of %s mapper for class '%s' is " - "only allowed from a %s mapper" % (np, self.class_.__name__, np)) + raise sa_exc.ArgumentError( + "Inheritance of %s mapper for class '%s' is " + "only allowed from a %s mapper" % + (np, self.class_.__name__, np)) # inherit_condition is optional. 
if self.local_table is None: self.local_table = self.inherits.local_table self.mapped_table = self.inherits.mapped_table self.single = True - elif not self.local_table is self.inherits.local_table: + elif self.local_table is not self.inherits.local_table: if self.concrete: self.mapped_table = self.local_table for mapper in self.iterate_to_root(): @@ -234,16 +969,23 @@ class Mapper(object): mapper._requires_row_aliasing = True else: if self.inherit_condition is None: - # figure out inherit condition from our table to the immediate table - # of the inherited mapper, not its full table which could pull in other - # stuff we dont want (allows test/inheritance.InheritTest4 to pass) - self.inherit_condition = sqlutil.join_condition(self.inherits.local_table, self.local_table) - self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition) + # figure out inherit condition from our table to the + # immediate table of the inherited mapper, not its + # full table which could pull in other stuff we don't + # want (allows test/inheritance.InheritTest4 to pass) + self.inherit_condition = sql_util.join_condition( + self.inherits.local_table, + self.local_table) + self.mapped_table = sql.join( + self.inherits.mapped_table, + self.local_table, + self.inherit_condition) fks = util.to_set(self.inherit_foreign_keys) self._inherits_equated_pairs = \ - sqlutil.criterion_as_pairs(self.mapped_table.onclause, - consider_as_foreign_keys=fks) + sql_util.criterion_as_pairs( + self.mapped_table.onclause, + consider_as_foreign_keys=fks) else: self.mapped_table = self.local_table @@ -255,30 +997,44 @@ class Mapper(object): if self.version_id_col is None: self.version_id_col = self.inherits.version_id_col self.version_id_generator = self.inherits.version_id_generator + elif self.inherits.version_id_col is not None and \ + self.version_id_col is not self.inherits.version_id_col: + util.warn( + "Inheriting version_id_col '%s' does not match inherited " + "version_id_col '%s' and will not automatically populate " + "the inherited versioning column. " + "version_id_col should only be specified on " + "the base-most mapper that includes versioning." % + (self.version_id_col.description, + self.inherits.version_id_col.description) + ) - for mapper in self.iterate_to_root(): - util.reset_memoized(mapper, '_equivalent_columns') - util.reset_memoized(mapper, '_sorted_tables') - - if self.order_by is False and not self.concrete and self.inherits.order_by is not False: + if self.order_by is False and \ + not self.concrete and \ + self.inherits.order_by is not False: self.order_by = self.inherits.order_by self.polymorphic_map = self.inherits.polymorphic_map self.batch = self.inherits.batch - self.inherits._inheriting_mappers.add(self) + self.inherits._inheriting_mappers.append(self) self.base_mapper = self.inherits.base_mapper self.passive_updates = self.inherits.passive_updates + self.passive_deletes = self.inherits.passive_deletes or \ + self.passive_deletes self._all_tables = self.inherits._all_tables if self.polymorphic_identity is not None: + if self.polymorphic_identity in self.polymorphic_map: + util.warn( + "Reassigning polymorphic association for identity %r " + "from %r to %r: Check for duplicate use of %r as " + "value for polymorphic_identity." 
%
+ (self.polymorphic_identity,
+ self.polymorphic_map[self.polymorphic_identity],
+ self, self.polymorphic_identity)
+ )
 self.polymorphic_map[self.polymorphic_identity] = self

- if self.polymorphic_on is None:
- for mapper in self.iterate_to_root():
- # try to set up polymorphic on using correesponding_column(); else leave
- # as None
- if mapper.polymorphic_on is not None:
- self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on)
- break
+
 else:
 self._all_tables = set()
 self.base_mapper = self
@@ -288,36 +1044,95 @@ class Mapper(object):
 self._identity_class = self.class_

 if self.mapped_table is None:
- raise sa_exc.ArgumentError("Mapper '%s' does not have a mapped_table specified." % self)
+ raise sa_exc.ArgumentError(
+ "Mapper '%s' does not have a mapped_table specified."
+ % self)

- def _configure_extensions(self):
- """Go through the global_extensions list as well as the list
- of ``MapperExtensions`` specified for this ``Mapper`` and
- creates a linked list of those extensions.
-
- """
- extlist = util.OrderedSet()
+ def _set_with_polymorphic(self, with_polymorphic):
+ if with_polymorphic == '*':
+ self.with_polymorphic = ('*', None)
+ elif isinstance(with_polymorphic, (tuple, list)):
+ if isinstance(
+ with_polymorphic[0], util.string_types + (tuple, list)):
+ self.with_polymorphic = with_polymorphic
+ else:
+ self.with_polymorphic = (with_polymorphic, None)
+ elif with_polymorphic is not None:
+ raise sa_exc.ArgumentError("Invalid setting for with_polymorphic")
+ else:
+ self.with_polymorphic = None

- extension = self.extension
- if extension:
- for ext_obj in util.to_list(extension):
- # local MapperExtensions have already instrumented the class
- extlist.add(ext_obj)
+ if isinstance(self.local_table, expression.SelectBase):
+ raise sa_exc.InvalidRequestError(
+ "When mapping against a select() construct, map against "
+ "an alias() of the construct instead. "
+ "This is because several databases don't allow a "
+ "SELECT from a subquery that does not have an alias."
+ ) + + if self.with_polymorphic and \ + isinstance(self.with_polymorphic[1], + expression.SelectBase): + self.with_polymorphic = (self.with_polymorphic[0], + self.with_polymorphic[1].alias()) + if self.configured: + self._expire_memoizations() + + def _set_concrete_base(self, mapper): + """Set the given :class:`.Mapper` as the 'inherits' for this + :class:`.Mapper`, assuming this :class:`.Mapper` is concrete + and does not already have an inherits.""" + + assert self.concrete + assert not self.inherits + assert isinstance(mapper, Mapper) + self.inherits = mapper + self.inherits.polymorphic_map.update(self.polymorphic_map) + self.polymorphic_map = self.inherits.polymorphic_map + for mapper in self.iterate_to_root(): + if mapper.polymorphic_on is not None: + mapper._requires_row_aliasing = True + self.batch = self.inherits.batch + for mp in self.self_and_descendants: + mp.base_mapper = self.inherits.base_mapper + self.inherits._inheriting_mappers.append(self) + self.passive_updates = self.inherits.passive_updates + self._all_tables = self.inherits._all_tables + for key, prop in mapper._props.items(): + if key not in self._props and \ + not self._should_exclude(key, key, local=False, + column=None): + self._adapt_inherited_property(key, prop, False) + + def _set_polymorphic_on(self, polymorphic_on): + self.polymorphic_on = polymorphic_on + self._configure_polymorphic_setter(True) + + def _configure_legacy_instrument_class(self): if self.inherits: - for ext in self.inherits.extension: - if ext not in extlist: - extlist.add(ext) + self.dispatch._update(self.inherits.dispatch) + super_extensions = set( + chain(*[m._deprecated_extensions + for m in self.inherits.iterate_to_root()])) else: - for ext in global_extensions: - if isinstance(ext, type): - ext = ext() - if ext not in extlist: - extlist.add(ext) + super_extensions = set() - self.extension = ExtensionCarrier() - for ext in extlist: - self.extension.append(ext) + for ext in self._deprecated_extensions: + if ext not in super_extensions: + ext._adapt_instrument_class(self, ext) + + def _configure_listeners(self): + if self.inherits: + super_extensions = set( + chain(*[m._deprecated_extensions + for m in self.inherits.iterate_to_root()])) + else: + super_extensions = set() + + for ext in self._deprecated_extensions: + if ext not in super_extensions: + ext._adapt_listener(self, ext) def _configure_class_instrumentation(self): """If this mapper is to be a primary mapper (i.e. the @@ -330,90 +1145,105 @@ class Mapper(object): auto-session attachment logic. """ + manager = attributes.manager_of_class(self.class_) - + if self.non_primary: - if not manager or manager.mapper is None: + if not manager or not manager.is_mapped: raise sa_exc.InvalidRequestError( "Class %s has no primary mapper configured. Configure " "a primary mapper first before setting up a non primary " - "Mapper.") + "Mapper." % self.class_) self.class_manager = manager + self._identity_class = manager.mapper._identity_class _mapper_registry[self] = True return if manager is not None: assert manager.class_ is self.class_ - if manager.mapper: + if manager.is_mapped: raise sa_exc.ArgumentError( "Class '%s' already has a primary mapper defined. " "Use non_primary=True to " "create a non primary Mapper. clear_mappers() will " "remove *all* current mappers from all classes." 
% self.class_) - #else: - # a ClassManager may already exist as - # ClassManager.instrument_attribute() creates + # else: + # a ClassManager may already exist as + # ClassManager.instrument_attribute() creates # new managers for each subclass if they don't yet exist. - + _mapper_registry[self] = True - self.extension.instrument_class(self, self.class_) + # note: this *must be called before instrumentation.register_class* + # to maintain the documented behavior of instrument_class + self.dispatch.instrument_class(self, self.class_) if manager is None: - manager = attributes.register_class(self.class_, - deferred_scalar_loader = _load_scalar_attributes - ) + manager = instrumentation.register_class(self.class_) self.class_manager = manager manager.mapper = self + manager.deferred_scalar_loader = util.partial( + loading.load_scalar_attributes, self) - # The remaining members can be added by any mapper, e_name None or not. + # The remaining members can be added by any mapper, + # e_name None or not. if manager.info.get(_INSTRUMENTOR, False): return - event_registry = manager.events - event_registry.add_listener('on_init', _event_on_init) - event_registry.add_listener('on_init_failure', _event_on_init_failure) - event_registry.add_listener('on_resurrect', _event_on_resurrect) - + event.listen(manager, 'first_init', _event_on_first_init, raw=True) + event.listen(manager, 'init', _event_on_init, raw=True) + for key, method in util.iterate_attributes(self.class_): if isinstance(method, types.FunctionType): if hasattr(method, '__sa_reconstructor__'): - event_registry.add_listener('on_load', method) + self._reconstructor = method + event.listen(manager, 'load', _event_on_load, raw=True) elif hasattr(method, '__sa_validators__'): + validation_opts = method.__sa_validation_opts__ for name in method.__sa_validators__: - self._validators[name] = method - - if 'reconstruct_instance' in self.extension: - def reconstruct(instance): - self.extension.reconstruct_instance(self, instance) - event_registry.add_listener('on_load', reconstruct) + if name in self.validators: + raise sa_exc.InvalidRequestError( + "A validation function for mapped " + "attribute %r on mapper %s already exists." % + (name, self)) + self.validators = self.validators.union( + {name: (method, validation_opts)} + ) manager.info[_INSTRUMENTOR] = self + @classmethod + def _configure_all(cls): + """Class-level path to the :func:`.configure_mappers` call. + """ + configure_mappers() + def dispose(self): # Disable any attribute-based compilation. 
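The decorator scan above is the hook behind the orm.validates() and
orm.reconstructor() APIs; a brief sketch of the user-facing side
(illustrative only; the Account class is hypothetical):

    from sqlalchemy import Column, Integer, String, inspect
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import reconstructor, validates

    Base = declarative_base()

    class Account(Base):
        __tablename__ = 'account'
        id = Column(Integer, primary_key=True)
        email = Column(String(100))

        @validates('email')
        def _check_email(self, key, value):
            # collected into Mapper.validators via __sa_validators__
            if '@' not in value:
                raise ValueError("invalid email: %r" % value)
            return value

        @reconstructor
        def _init_on_load(self):
            # attached to the 'load' event via __sa_reconstructor__
            self._cache = {}

    assert 'email' in inspect(Account).validators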
- self.compiled = True - - if hasattr(self, '_compile_failed'): - del self._compile_failed - - if not self.non_primary and self.class_manager.mapper is self: - attributes.unregister_class(self.class_) + self.configured = True + + if hasattr(self, '_configure_failed'): + del self._configure_failed + + if not self.non_primary and \ + self.class_manager is not None and \ + self.class_manager.is_mapped and \ + self.class_manager.mapper is self: + instrumentation.unregister_class(self.class_) def _configure_pks(self): - - self.tables = sqlutil.find_tables(self.mapped_table) - - if not self.tables: - raise sa_exc.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table)) + self.tables = sql_util.find_tables(self.mapped_table) self._pks_by_table = {} self._cols_by_table = {} - all_cols = util.column_set(chain(*[col.proxy_set for col in self._columntoproperty])) + all_cols = util.column_set(chain(*[ + col.proxy_set for col in + self._columntoproperty])) + pk_cols = util.column_set(c for c in all_cols if c.primary_key) # identify primary key columns which are also mapped by this mapper. @@ -421,50 +1251,75 @@ class Mapper(object): self._all_tables.update(tables) for t in tables: if t.primary_key and pk_cols.issuperset(t.primary_key): - # ordering is important since it determines the ordering of mapper.primary_key (and therefore query.get()) - self._pks_by_table[t] = util.ordered_column_set(t.primary_key).intersection(pk_cols) - self._cols_by_table[t] = util.ordered_column_set(t.c).intersection(all_cols) + # ordering is important since it determines the ordering of + # mapper.primary_key (and therefore query.get()) + self._pks_by_table[t] = \ + util.ordered_column_set(t.primary_key).\ + intersection(pk_cols) + self._cols_by_table[t] = \ + util.ordered_column_set(t.c).\ + intersection(all_cols) + + # if explicit PK argument sent, add those columns to the + # primary key mappings + if self._primary_key_argument: + for k in self._primary_key_argument: + if k.table not in self._pks_by_table: + self._pks_by_table[k.table] = util.OrderedSet() + self._pks_by_table[k.table].add(k) + + # otherwise, see that we got a full PK for the mapped table + elif self.mapped_table not in self._pks_by_table or \ + len(self._pks_by_table[self.mapped_table]) == 0: + raise sa_exc.ArgumentError( + "Mapper %s could not assemble any primary " + "key columns for mapped table '%s'" % + (self, self.mapped_table.description)) + elif self.local_table not in self._pks_by_table and \ + isinstance(self.local_table, schema.Table): + util.warn("Could not assemble any primary " + "keys for locally mapped table '%s' - " + "no rows will be persisted in this Table." 
+ % self.local_table.description) + + if self.inherits and \ + not self.concrete and \ + not self._primary_key_argument: + # if inheriting, the "primary key" for this mapper is + # that of the inheriting (unless concrete or explicit) + self.primary_key = self.inherits.primary_key + else: + # determine primary key from argument or mapped_table pks - + # reduce to the minimal set of columns + if self._primary_key_argument: + primary_key = sql_util.reduce_columns( + [self.mapped_table.corresponding_column(c) for c in + self._primary_key_argument], + ignore_nonexistent_tables=True) + else: + primary_key = sql_util.reduce_columns( + self._pks_by_table[self.mapped_table], + ignore_nonexistent_tables=True) + + if len(primary_key) == 0: + raise sa_exc.ArgumentError( + "Mapper %s could not assemble any primary " + "key columns for mapped table '%s'" % + (self, self.mapped_table.description)) + + self.primary_key = tuple(primary_key) + self._log("Identified primary key columns: %s", primary_key) # determine cols that aren't expressed within our tables; mark these # as "read only" properties which are refreshed upon INSERT/UPDATE self._readonly_props = set( self._columntoproperty[col] for col in self._columntoproperty - if not hasattr(col, 'table') or col.table not in self._cols_by_table) - - # if explicit PK argument sent, add those columns to the primary key mappings - if self.primary_key_argument: - for k in self.primary_key_argument: - if k.table not in self._pks_by_table: - self._pks_by_table[k.table] = util.OrderedSet() - self._pks_by_table[k.table].add(k) - - if self.mapped_table not in self._pks_by_table or len(self._pks_by_table[self.mapped_table]) == 0: - raise sa_exc.ArgumentError("Mapper %s could not assemble any primary " - "key columns for mapped table '%s'" % (self, self.mapped_table.description)) - - if self.inherits and not self.concrete and not self.primary_key_argument: - # if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit) - self.primary_key = self.inherits.primary_key - else: - # determine primary key from argument or mapped_table pks - reduce to the minimal set of columns - if self.primary_key_argument: - primary_key = sqlutil.reduce_columns( - [self.mapped_table.corresponding_column(c) for c in self.primary_key_argument], - ignore_nonexistent_tables=True) - else: - primary_key = sqlutil.reduce_columns( - self._pks_by_table[self.mapped_table], ignore_nonexistent_tables=True) - - if len(primary_key) == 0: - raise sa_exc.ArgumentError("Mapper %s could not assemble any primary " - "key columns for mapped table '%s'" % (self, self.mapped_table.description)) - - self.primary_key = primary_key - self._log("Identified primary key columns: %s", primary_key) + if self._columntoproperty[col] not in self._identity_key_props and + (not hasattr(col, 'table') or + col.table not in self._cols_by_table)) def _configure_properties(self): - # Column and other ClauseElement objects which are mapped self.columns = self.c = util.OrderedProperties() @@ -474,17 +1329,19 @@ class Mapper(object): # table columns mapped to lists of MapperProperty objects # using a list allows a single column to be defined as # populating multiple object attributes - self._columntoproperty = util.column_dict() + self._columntoproperty = _ColumnMapping(self) # load custom properties if self._init_properties: - for key, prop in self._init_properties.iteritems(): + for key, prop in self._init_properties.items(): self._configure_property(key, prop, False) # pull properties 
from the inherited mapper if any. if self.inherits: - for key, prop in self.inherits._props.iteritems(): - if key not in self._props and not self._should_exclude(key, key, local=False): + for key, prop in self.inherits._props.items(): + if key not in self._props and \ + not self._should_exclude(key, key, local=False, + column=None): self._adapt_inherited_property(key, prop, False) # create properties for each column in the mapped table, @@ -495,7 +1352,11 @@ class Mapper(object): column_key = (self.column_prefix or '') + column.key - if self._should_exclude(column.key, column_key, local=self.local_table.c.contains_column(column)): + if self._should_exclude( + column.key, column_key, + local=self.local_table.c.contains_column(column), + column=column + ): continue # adjust the "key" used for this column to that @@ -504,82 +1365,221 @@ class Mapper(object): if column in mapper._columntoproperty: column_key = mapper._columntoproperty[column].key - self._configure_property(column_key, column, init=False, setparent=True) + self._configure_property(column_key, + column, + init=False, + setparent=True) - # do a special check for the "discriminiator" column, as it may only be present - # in the 'with_polymorphic' selectable but we need it for the base mapper - if self.polymorphic_on is not None and self.polymorphic_on not in self._columntoproperty: - col = self.mapped_table.corresponding_column(self.polymorphic_on) - if col is None: - instrument = False - col = self.polymorphic_on + def _configure_polymorphic_setter(self, init=False): + """Configure an attribute on the mapper representing the + 'polymorphic_on' column, if applicable, and not + already generated by _configure_properties (which is typical). + + Also create a setter function which will assign this + attribute to the value of the 'polymorphic_identity' + upon instance construction, also if applicable. This + routine will run when an instance is created. + + """ + setter = False + + if self.polymorphic_on is not None: + setter = True + + if isinstance(self.polymorphic_on, util.string_types): + # polymorphic_on specified as a string - link + # it to mapped ColumnProperty + try: + self.polymorphic_on = self._props[self.polymorphic_on] + except KeyError: + raise sa_exc.ArgumentError( + "Can't determine polymorphic_on " + "value '%s' - no attribute is " + "mapped to this name." % self.polymorphic_on) + + if self.polymorphic_on in self._columntoproperty: + # polymorphic_on is a column that is already mapped + # to a ColumnProperty + prop = self._columntoproperty[self.polymorphic_on] + elif isinstance(self.polymorphic_on, MapperProperty): + # polymorphic_on is directly a MapperProperty, + # ensure it's a ColumnProperty + if not isinstance(self.polymorphic_on, + properties.ColumnProperty): + raise sa_exc.ArgumentError( + "Only direct column-mapped " + "property or SQL expression " + "can be passed for polymorphic_on") + prop = self.polymorphic_on + elif not expression._is_column(self.polymorphic_on): + # polymorphic_on is not a Column and not a ColumnProperty; + # not supported right now. 
+ raise sa_exc.ArgumentError(
+ "Only direct column-mapped "
+ "property or SQL expression "
+ "can be passed for polymorphic_on"
+ )
 else:
- instrument = True
- if self._should_exclude(col.key, col.key, local=False):
- raise sa_exc.InvalidRequestError("Cannot exclude or override the discriminator column %r" % col.key)
- self._configure_property(col.key, ColumnProperty(col, _instrument=instrument), init=False, setparent=True)
+ # polymorphic_on is a Column or SQL expression and
+ # doesn't appear to be mapped. this means it can be either
+ # 1. only present in the with_polymorphic selectable or
+ # 2. a totally standalone SQL expression which we'd
+ # hope is compatible with this mapper's mapped_table
+ col = self.mapped_table.corresponding_column(
+ self.polymorphic_on)
+ if col is None:
+ # polymorphic_on doesn't derive from any
+ # column/expression present in the mapped
+ # table. we will make a "hidden" ColumnProperty
+ # for it. Just check that if it's directly a
+ # schema.Column and we have with_polymorphic, it's
+ # likely a user error if the schema.Column isn't
+ # represented somehow in either mapped_table or
+ # with_polymorphic. Otherwise as of 0.7.4 we
+ # just go with it and assume the user wants it
+ # that way (i.e. a CASE statement)
+ setter = False
+ instrument = False
+ col = self.polymorphic_on
+ if isinstance(col, schema.Column) and (
+ self.with_polymorphic is None or
+ self.with_polymorphic[1].
+ corresponding_column(col) is None):
+ raise sa_exc.InvalidRequestError(
+ "Could not map polymorphic_on column "
+ "'%s' to the mapped table - polymorphic "
+ "loads will not function properly"
+ % col.description)
+ else:
+ # column/expression that polymorphic_on derives from
+ # is present in our mapped table
+ # and is probably mapped, but polymorphic_on itself
+ # is not. This happens when
+ # the polymorphic_on is only directly present in the
+ # with_polymorphic selectable, as when using
+ # polymorphic_union.
+ # we'll make a separate ColumnProperty for it.
+ instrument = True
+ key = getattr(col, 'key', None)
+ if key:
+ if self._should_exclude(col.key, col.key, False, col):
+ raise sa_exc.InvalidRequestError(
+ "Cannot exclude or override the "
+ "discriminator column %r" %
+ col.key)
+ else:
+ self.polymorphic_on = col = \
+ col.label("_sa_polymorphic_on")
+ key = col.key
+
+ prop = properties.ColumnProperty(col, _instrument=instrument)
+ self._configure_property(key, prop, init=init, setparent=True)
+
+ # the actual polymorphic_on should be the first public-facing
+ # column in the property
+ self.polymorphic_on = prop.columns[0]
+ polymorphic_key = prop.key
+
+ else:
+ # no polymorphic_on was set.
+ # check inheriting mappers for one.
+ for mapper in self.iterate_to_root():
+ # determine if polymorphic_on of the parent
+ # should be propagated here. If the col
+ # is present in our mapped table, or if our mapped
+ # table is the same as the parent (i.e. single table
+ # inheritance), we can use it
+ if mapper.polymorphic_on is not None:
+ if self.mapped_table is mapper.mapped_table:
+ self.polymorphic_on = mapper.polymorphic_on
+ else:
+ self.polymorphic_on = \
+ self.mapped_table.corresponding_column(
+ mapper.polymorphic_on)
+ # we can use the parent mapper's _set_polymorphic_identity
+ # directly; it ensures the polymorphic_identity of the
+ # instance's mapper is used, so it is portable to subclasses.
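For the standalone-expression branch handled above, a sketch of the
user-facing form being described (illustrative only; Widget/Special are
hypothetical, and the case() call follows the 1.x API):

    from sqlalchemy import Column, Integer, String, case
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Widget(Base):
        __tablename__ = 'widget'
        id = Column(Integer, primary_key=True)
        kind = Column(String(20))
        __mapper_args__ = {
            # a standalone expression; the mapper wraps it in a "hidden"
            # ColumnProperty labeled _sa_polymorphic_on, per the code above
            'polymorphic_on': case(
                [(kind == 'special', 'special')],
                else_='plain'),
            'polymorphic_identity': 'plain',
        }

    class Special(Widget):
        # single-table inheritance; rows load as Special when the
        # CASE expression yields 'special'
        __mapper_args__ = {'polymorphic_identity': 'special'}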
+ if self.polymorphic_on is not None: + self._set_polymorphic_identity = \ + mapper._set_polymorphic_identity + self._validate_polymorphic_identity = \ + mapper._validate_polymorphic_identity + else: + self._set_polymorphic_identity = None + return + + if setter: + def _set_polymorphic_identity(state): + dict_ = state.dict + state.get_impl(polymorphic_key).set( + state, dict_, + state.manager.mapper.polymorphic_identity, + None) + + def _validate_polymorphic_identity(mapper, state, dict_): + if polymorphic_key in dict_ and \ + dict_[polymorphic_key] not in \ + mapper._acceptable_polymorphic_identities: + util.warn_limited( + "Flushing object %s with " + "incompatible polymorphic identity %r; the " + "object may not refresh and/or load correctly", + (state_str(state), dict_[polymorphic_key]) + ) + + self._set_polymorphic_identity = _set_polymorphic_identity + self._validate_polymorphic_identity = \ + _validate_polymorphic_identity + else: + self._set_polymorphic_identity = None + + _validate_polymorphic_identity = None + + @_memoized_configured_property + def _version_id_prop(self): + if self.version_id_col is not None: + return self._columntoproperty[self.version_id_col] + else: + return None + + @_memoized_configured_property + def _acceptable_polymorphic_identities(self): + identities = set() + + stack = deque([self]) + while stack: + item = stack.popleft() + if item.mapped_table is self.mapped_table: + identities.add(item.polymorphic_identity) + stack.extend(item._inheriting_mappers) + + return identities + + @_memoized_configured_property + def _prop_set(self): + return frozenset(self._props.values()) def _adapt_inherited_property(self, key, prop, init): if not self.concrete: self._configure_property(key, prop, init=False, setparent=False) elif key not in self._props: - self._configure_property(key, ConcreteInheritedProperty(), init=init, setparent=True) - + self._configure_property( + key, + properties.ConcreteInheritedProperty(), + init=init, setparent=True) + def _configure_property(self, key, prop, init=True, setparent=True): self._log("_configure_property(%s, %s)", key, prop.__class__.__name__) if not isinstance(prop, MapperProperty): - # we were passed a Column or a list of Columns; generate a ColumnProperty - columns = util.to_list(prop) - column = columns[0] - if not expression.is_column(column): - raise sa_exc.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop)) + prop = self._property_from_column(key, prop) - prop = self._props.get(key, None) - - if isinstance(prop, ColumnProperty): - # TODO: the "property already exists" case is still not well defined here. - # assuming single-column, etc. - - if prop.parent is not self: - # existing ColumnProperty from an inheriting mapper. - # make a copy and append our column to it - prop = prop.copy() - prop.columns.append(column) - self._log("appending to existing ColumnProperty %s" % (key)) - elif prop is None or isinstance(prop, ConcreteInheritedProperty): - mapped_column = [] - for c in columns: - mc = self.mapped_table.corresponding_column(c) - if mc is None: - mc = self.local_table.corresponding_column(c) - if mc is not None: - # if the column is in the local table but not the mapped table, - # this corresponds to adding a column after the fact to the local table. - # [ticket:1523] - self.mapped_table._reset_exported() - mc = self.mapped_table.corresponding_column(c) - if mc is None: - raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table. 
" - "Use the `column_property()` function to force this column " - "to be mapped as a read-only attribute." % c) - mapped_column.append(mc) - prop = ColumnProperty(*mapped_column) - else: - raise sa_exc.ArgumentError("WARNING: column '%s' conflicts with property '%r'. " - "To resolve this, map the column to the class under a different " - "name in the 'properties' dictionary. Or, to remove all awareness " - "of the column entirely (including its availability as a foreign key), " - "use the 'include_properties' or 'exclude_properties' mapper arguments " - "to control specifically which table columns get mapped." % (column.key, prop)) - - if isinstance(prop, ColumnProperty): + if isinstance(prop, properties.ColumnProperty): col = self.mapped_table.corresponding_column(prop.columns[0]) - - # if the column is not present in the mapped table, - # test if a column has been added after the fact to the parent table - # (or their parent, etc.) - # [ticket:1570] + + # if the column is not present in the mapped table, + # test if a column has been added after the fact to the + # parent table (or their parent, etc.) [ticket:1570] if col is None and self.inherits: path = [self] for m in self.inherits.iterate_to_root(): @@ -587,74 +1587,76 @@ class Mapper(object): if col is not None: for m2 in path: m2.mapped_table._reset_exported() - col = self.mapped_table.corresponding_column(prop.columns[0]) + col = self.mapped_table.corresponding_column( + prop.columns[0]) break path.append(m) - - # otherwise, col might not be present! the selectable given - # to the mapper need not include "deferred" - # columns (included in zblog tests) + + # subquery expression, column not present in the mapped + # selectable. if col is None: col = prop.columns[0] - # column is coming in after _readonly_props was initialized; check - # for 'readonly' + # column is coming in after _readonly_props was + # initialized; check for 'readonly' if hasattr(self, '_readonly_props') and \ - (not hasattr(col, 'table') or col.table not in self._cols_by_table): - self._readonly_props.add(prop) + (not hasattr(col, 'table') or + col.table not in self._cols_by_table): + self._readonly_props.add(prop) else: - # if column is coming in after _cols_by_table was initialized, ensure the col is in the - # right set - if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]: + # if column is coming in after _cols_by_table was + # initialized, ensure the col is in the right set + if hasattr(self, '_cols_by_table') and \ + col.table in self._cols_by_table and \ + col not in self._cols_by_table[col.table]: self._cols_by_table[col.table].add(col) - - # if this ColumnProperty represents the "polymorphic discriminator" - # column, mark it. We'll need this when rendering columns - # in SELECT statements. + + # if this properties.ColumnProperty represents the "polymorphic + # discriminator" column, mark it. We'll need this when rendering + # columns in SELECT statements. 
 if not hasattr(prop, '_is_polymorphic_discriminator'):
- prop._is_polymorphic_discriminator = (col is self.polymorphic_on or prop.columns[0] is self.polymorphic_on)
-
+ prop._is_polymorphic_discriminator = \
+ (col is self.polymorphic_on or
+ prop.columns[0] is self.polymorphic_on)
+
 self.columns[key] = col
- for col in prop.columns:
+ for col in prop.columns + prop._orig_columns:
 for col in col.proxy_set:
 self._columntoproperty[col] = prop

- elif isinstance(prop, (ComparableProperty, SynonymProperty)) and setparent:
- if prop.descriptor is None:
- desc = getattr(self.class_, key, None)
- if self._is_userland_descriptor(desc):
- prop.descriptor = desc
- if getattr(prop, 'map_column', False):
- if key not in self.mapped_table.c:
- raise sa_exc.ArgumentError(
- "Can't compile synonym '%s': no column on table '%s' named '%s'"
- % (prop.name, self.mapped_table.description, key))
- elif self.mapped_table.c[key] in self._columntoproperty and \
- self._columntoproperty[self.mapped_table.c[key]].key == prop.name:
- raise sa_exc.ArgumentError(
- "Can't call map_column=True for synonym %r=%r, "
- "a ColumnProperty already exists keyed to the name %r "
- "for column %r" %
- (key, prop.name, prop.name, key)
- )
- p = ColumnProperty(self.mapped_table.c[key])
- self._configure_property(prop.name, p, init=init, setparent=setparent)
- p._mapped_by_synonym = key
-
- if key in self._props and getattr(self._props[key], '_mapped_by_synonym', False):
- syn = self._props[key]._mapped_by_synonym
- raise sa_exc.ArgumentError(
- "Can't call map_column=True for synonym %r=%r, "
- "a ColumnProperty already exists keyed to the name "
- "%r for column %r" % (syn, key, key, syn)
- )
-
- self._props[key] = prop
 prop.key = key

 if setparent:
- prop.set_parent(self)
+ prop.set_parent(self, init)
+
+ if key in self._props and \
+ getattr(self._props[key], '_mapped_by_synonym', False):
+ syn = self._props[key]._mapped_by_synonym
+ raise sa_exc.ArgumentError(
+ "Can't call map_column=True for synonym %r=%r, "
+ "a ColumnProperty already exists keyed to the name "
+ "%r for column %r" % (syn, key, key, syn)
+ )
+
+ if key in self._props and \
+ not isinstance(prop, properties.ColumnProperty) and \
+ not isinstance(
+ self._props[key],
+ (
+ properties.ColumnProperty,
+ properties.ConcreteInheritedProperty)
+ ):
+ util.warn("Property %s on %s being replaced with new "
+ "property %s; the old property will be discarded" % (
+ self._props[key],
+ self,
+ prop,
+ ))
+ oldprop = self._props[key]
+ self._path_registry.pop(oldprop, None)
+
+ self._props[key] = prop

 if not self.non_primary:
 prop.instrument_class(self)
@@ -666,57 +1668,84 @@ class Mapper(object):
 prop.init()
 prop.post_instrument_class(self)
+ if self.configured:
+ self._expire_memoizations()

- def compile(self):
- """Compile this mapper and all other non-compiled mappers.
+ def _property_from_column(self, key, prop):
+ """generate/update a :class:`.ColumnProperty` given a
+ :class:`.Column` object.
 """
- This method checks the local compiled status as well as for
- any new mappers that have been defined, and is safe to call
- repeatedly.
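The method introduced here is where same-named columns arriving from a
multi-table selectable get combined under one attribute; in the ambiguous
case it emits the "Implicitly combining column" warning seen just below.
The explicit alternative is a column_property() naming the columns
directly. A sketch under assumed table and class names:

    from sqlalchemy import (Column, ForeignKey, Integer, MetaData,
                            String, Table, join)
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import column_property

    metadata = MetaData()
    user = Table('user', metadata,
                 Column('id', Integer, primary_key=True),
                 Column('name', String(50)))
    address = Table('address', metadata,
                    Column('id', Integer, primary_key=True),
                    Column('user_id', Integer, ForeignKey('user.id')),
                    Column('email', String(50)))

    Base = declarative_base(metadata=metadata)

    class AddressUser(Base):
        __table__ = join(user, address)
        # name the join-equated pair explicitly under one attribute;
        # leaving the two same-named 'id' columns implicit would hit
        # the warning path in this method instead
        id = column_property(user.c.id, address.c.user_id)
        address_id = address.c.id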
+ # we were passed a Column or a list of Columns; + # generate a properties.ColumnProperty + columns = util.to_list(prop) + column = columns[0] + if not expression._is_column(column): + raise sa_exc.ArgumentError( + "%s=%r is not an instance of MapperProperty or Column" + % (key, prop)) - """ - global _new_mappers - if self.compiled and not _new_mappers: - return self + prop = self._props.get(key, None) - _COMPILE_MUTEX.acquire() - try: - try: - global _already_compiling - if _already_compiling: - return - _already_compiling = True - try: + if isinstance(prop, properties.ColumnProperty): + if ( + not self._inherits_equated_pairs or + (prop.columns[0], column) not in self._inherits_equated_pairs + ) and \ + not prop.columns[0].shares_lineage(column) and \ + prop.columns[0] is not self.version_id_col and \ + column is not self.version_id_col: + warn_only = prop.parent is not self + msg = ("Implicitly combining column %s with column " + "%s under attribute '%s'. Please configure one " + "or more attributes for these same-named columns " + "explicitly." % (prop.columns[-1], column, key)) + if warn_only: + util.warn(msg) + else: + raise sa_exc.InvalidRequestError(msg) - # double-check inside mutex - if self.compiled and not _new_mappers: - return self - - # initialize properties on all mappers - # note that _mapper_registry is unordered, which - # may randomly conceal/reveal issues related to - # the order of mapper compilation - for mapper in list(_mapper_registry): - if getattr(mapper, '_compile_failed', False): - raise sa_exc.InvalidRequestError( - "One or more mappers failed to compile. " - "Exception was probably " - "suppressed within a hasattr() call. " - "Message was: %s" % mapper._compile_failed) - if not mapper.compiled: - mapper._post_configure_properties() - - _new_mappers = False - return self - finally: - _already_compiling = False - except: - import sys - exc = sys.exc_info()[1] - self._compile_failed = exc - raise - finally: - _COMPILE_MUTEX.release() + # existing properties.ColumnProperty from an inheriting + # mapper. make a copy and append our column to it + prop = prop.copy() + prop.columns.insert(0, column) + self._log("inserting column to existing list " + "in properties.ColumnProperty %s" % (key)) + return prop + elif prop is None or isinstance(prop, + properties.ConcreteInheritedProperty): + mapped_column = [] + for c in columns: + mc = self.mapped_table.corresponding_column(c) + if mc is None: + mc = self.local_table.corresponding_column(c) + if mc is not None: + # if the column is in the local table but not the + # mapped table, this corresponds to adding a + # column after the fact to the local table. + # [ticket:1523] + self.mapped_table._reset_exported() + mc = self.mapped_table.corresponding_column(c) + if mc is None: + raise sa_exc.ArgumentError( + "When configuring property '%s' on %s, " + "column '%s' is not represented in the mapper's " + "table. Use the `column_property()` function to " + "force this column to be mapped as a read-only " + "attribute." % (key, self, c)) + mapped_column.append(mc) + return properties.ColumnProperty(*mapped_column) + else: + raise sa_exc.ArgumentError( + "WARNING: when configuring property '%s' on %s, " + "column '%s' conflicts with property '%r'. " + "To resolve this, map the column to the class under a " + "different name in the 'properties' dictionary. 
Or, " + "to remove all awareness of the column entirely " + "(including its availability as a foreign key), " + "use the 'include_properties' or 'exclude_properties' " + "mapper arguments to control specifically which table " + "columns get mapped." % + (key, self, column.key, prop)) def _post_configure_properties(self): """Call the ``init()`` method on all ``MapperProperties`` @@ -724,63 +1753,66 @@ class Mapper(object): This is a deferred configuration step which is intended to execute once all mappers have been constructed. - + """ self._log("_post_configure_properties() started") - l = [(key, prop) for key, prop in self._props.iteritems()] + l = [(key, prop) for key, prop in self._props.items()] for key, prop in l: self._log("initialize prop %s", key) - - if prop.parent is self and not prop._compile_started: + + if prop.parent is self and not prop._configure_started: prop.init() - - if prop._compile_finished: + + if prop._configure_finished: prop.post_instrument_class(self) - + self._log("_post_configure_properties() complete") - self.compiled = True - + self.configured = True + def add_properties(self, dict_of_properties): """Add the given dictionary of properties to this mapper, using `add_property`. """ - for key, value in dict_of_properties.iteritems(): + for key, value in dict_of_properties.items(): self.add_property(key, value) def add_property(self, key, prop): """Add an individual MapperProperty to this mapper. - If the mapper has not been compiled yet, just adds the + If the mapper has not been configured yet, just adds the property to the initial properties dictionary sent to the - constructor. If this Mapper has already been compiled, then - the given MapperProperty is compiled immediately. + constructor. If this Mapper has already been configured, then + the given MapperProperty is configured immediately. 
""" self._init_properties[key] = prop - self._configure_property(key, prop, init=self.compiled) + self._configure_property(key, prop, init=self.configured) + def _expire_memoizations(self): + for mapper in self.iterate_to_root(): + _memoized_configured_property.expire_instance(mapper) + + @property + def _log_desc(self): + return "(" + self.class_.__name__ + \ + "|" + \ + (self.local_table is not None and + self.local_table.description or + str(self.local_table)) +\ + (self.non_primary and + "|non-primary" or "") + ")" def _log(self, msg, *args): self.logger.info( - "(" + self.class_.__name__ + - "|" + - (self.local_table is not None and - self.local_table.description or - str(self.local_table)) + - (self.non_primary and "|non-primary" or "") + ") " + - msg, *args) + "%s " + msg, *((self._log_desc,) + args) + ) def _log_debug(self, msg, *args): self.logger.debug( - "(" + self.class_.__name__ + - "|" + - (self.local_table is not None and - self.local_table.description - or str(self.local_table)) + - (self.non_primary and "|non-primary" or "") + ") " + - msg, *args) + "%s " + msg, *((self._log_desc,) + args) + ) def __repr__(self): return '' % ( @@ -789,106 +1821,136 @@ class Mapper(object): def __str__(self): return "Mapper|%s|%s%s" % ( self.class_.__name__, - self.local_table is not None and self.local_table.description or None, + self.local_table is not None and + self.local_table.description or None, self.non_primary and "|non-primary" or "" ) def _is_orphan(self, state): - o = False + orphan_possible = False for mapper in self.iterate_to_root(): - for (key, cls) in mapper.delete_orphans: - if attributes.manager_of_class(cls).has_parent( - state, key, optimistic=_state_has_identity(state)): + for (key, cls) in mapper._delete_orphans: + orphan_possible = True + + has_parent = attributes.manager_of_class(cls).has_parent( + state, key, optimistic=state.has_identity) + + if self.legacy_is_orphan and has_parent: return False - o = o or bool(mapper.delete_orphans) - return o + elif not self.legacy_is_orphan and not has_parent: + return True + + if self.legacy_is_orphan: + return orphan_possible + else: + return False def has_property(self, key): return key in self._props - def get_property(self, key, resolve_synonyms=False, raiseerr=True): - """return a MapperProperty associated with the given key.""" + def get_property(self, key, _configure_mappers=True): + """return a MapperProperty associated with the given key. 
+ """

- if not self.compiled:
- self.compile()
- return self._get_property(key, resolve_synonyms=resolve_synonyms, raiseerr=raiseerr)
+ if _configure_mappers and Mapper._new_mappers:
+ configure_mappers()
+
+ try:
+ return self._props[key]
+ except KeyError:
+ raise sa_exc.InvalidRequestError(
+ "Mapper '%s' has no property '%s'" % (self, key))
+
+ def get_property_by_column(self, column):
+ """Given a :class:`.Column` object, return the
+ :class:`.MapperProperty` which maps this column."""
+
+ return self._columntoproperty[column]

- def _get_property(self, key, resolve_synonyms=False, raiseerr=True):
- prop = self._props.get(key, None)
- if resolve_synonyms:
- while isinstance(prop, SynonymProperty):
- prop = self._props.get(prop.name, None)
- if prop is None and raiseerr:
- raise sa_exc.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key))
- return prop
-
 @property
 def iterate_properties(self):
 """return an iterator of all MapperProperty objects."""
- if not self.compiled:
- self.compile()
- return self._props.itervalues()
+ if Mapper._new_mappers:
+ configure_mappers()
+ return iter(self._props.values())

 def _mappers_from_spec(self, spec, selectable):
- """given a with_polymorphic() argument, return the set of mappers it represents.
+ """given a with_polymorphic() argument, return the set of mappers it
+ represents.

- Trims the list of mappers to just those represented within the given selectable, if present.
- This helps some more legacy-ish mappings.
+ Trims the list of mappers to just those represented within the given
+ selectable, if present. This helps some more legacy-ish mappings.

 """
 if spec == '*':
- mappers = list(self.polymorphic_iterator())
+ mappers = list(self.self_and_descendants)
 elif spec:
- mappers = [_class_to_mapper(m) for m in util.to_list(spec)]
- for m in mappers:
+ mappers = set()
+ for m in util.to_list(spec):
+ m = _class_to_mapper(m)
 if not m.isa(self):
- raise sa_exc.InvalidRequestError("%r does not inherit from %r" % (m, self))
+ raise sa_exc.InvalidRequestError(
+ "%r does not inherit from %r" %
+ (m, self))
+
+ if selectable is None:
+ mappers.update(m.iterate_to_root())
+ else:
+ mappers.add(m)
+ mappers = [m for m in self.self_and_descendants if m in mappers]
 else:
 mappers = []

 if selectable is not None:
- tables = set(sqlutil.find_tables(selectable, include_aliases=True))
+ tables = set(sql_util.find_tables(selectable,
+ include_aliases=True))
 mappers = [m for m in mappers if m.local_table in tables]
-
 return mappers

- def _selectable_from_mappers(self, mappers):
- """given a list of mappers (assumed to be within this mapper's inheritance hierarchy),
- construct an outerjoin amongst those mapper's mapped tables.
+ def _selectable_from_mappers(self, mappers, innerjoin):
+ """given a list of mappers (assumed to be within this mapper's
+ inheritance hierarchy), construct an outerjoin amongst those mappers'
+ mapped tables.
""" - from_obj = self.mapped_table for m in mappers: if m is self: continue if m.concrete: - raise sa_exc.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.") + raise sa_exc.InvalidRequestError( + "'with_polymorphic()' requires 'selectable' argument " + "when concrete-inheriting mappers are used.") elif not m.single: - from_obj = from_obj.outerjoin(m.local_table, m.inherit_condition) + if innerjoin: + from_obj = from_obj.join(m.local_table, + m.inherit_condition) + else: + from_obj = from_obj.outerjoin(m.local_table, + m.inherit_condition) return from_obj - @property + @_memoized_configured_property def _single_table_criterion(self): if self.single and \ - self.inherits and \ - self.polymorphic_on is not None and \ - self.polymorphic_identity is not None: + self.inherits and \ + self.polymorphic_on is not None: return self.polymorphic_on.in_( m.polymorphic_identity - for m in self.polymorphic_iterator()) + for m in self.self_and_descendants) else: return None - - - @util.memoized_property + + @_memoized_configured_property def _with_polymorphic_mappers(self): + if Mapper._new_mappers: + configure_mappers() if not self.with_polymorphic: - return [self] + return [] return self._mappers_from_spec(*self.with_polymorphic) - @util.memoized_property + @_memoized_configured_property def _with_polymorphic_selectable(self): if not self.with_polymorphic: return self.mapped_table @@ -897,24 +1959,152 @@ class Mapper(object): if selectable is not None: return selectable else: - return self._selectable_from_mappers(self._mappers_from_spec(spec, selectable)) + return self._selectable_from_mappers( + self._mappers_from_spec(spec, selectable), + False) - def _with_polymorphic_args(self, spec=None, selectable=False): + with_polymorphic_mappers = _with_polymorphic_mappers + """The list of :class:`.Mapper` objects included in the + default "polymorphic" query. 
+ + """ + + @_memoized_configured_property + def _insert_cols_evaluating_none(self): + return dict( + ( + table, + frozenset( + col.key for col in columns + if col.type.should_evaluate_none + ) + ) + for table, columns in self._cols_by_table.items() + ) + + @_memoized_configured_property + def _insert_cols_as_none(self): + return dict( + ( + table, + frozenset( + col.key for col in columns + if not col.primary_key and + not col.server_default and not col.default + and not col.type.should_evaluate_none) + ) + for table, columns in self._cols_by_table.items() + ) + + @_memoized_configured_property + def _propkey_to_col(self): + return dict( + ( + table, + dict( + (self._columntoproperty[col].key, col) + for col in columns + ) + ) + for table, columns in self._cols_by_table.items() + ) + + @_memoized_configured_property + def _pk_keys_by_table(self): + return dict( + ( + table, + frozenset([col.key for col in pks]) + ) + for table, pks in self._pks_by_table.items() + ) + + @_memoized_configured_property + def _pk_attr_keys_by_table(self): + return dict( + ( + table, + frozenset([self._columntoproperty[col].key for col in pks]) + ) + for table, pks in self._pks_by_table.items() + ) + + @_memoized_configured_property + def _server_default_cols(self): + return dict( + ( + table, + frozenset([ + col.key for col in columns + if col.server_default is not None]) + ) + for table, columns in self._cols_by_table.items() + ) + + @_memoized_configured_property + def _server_default_plus_onupdate_propkeys(self): + result = set() + + for table, columns in self._cols_by_table.items(): + for col in columns: + if ( + ( + col.server_default is not None or + col.server_onupdate is not None + ) and col in self._columntoproperty + ): + result.add(self._columntoproperty[col].key) + + return result + + @_memoized_configured_property + def _server_onupdate_default_cols(self): + return dict( + ( + table, + frozenset([ + col.key for col in columns + if col.server_onupdate is not None]) + ) + for table, columns in self._cols_by_table.items() + ) + + @property + def selectable(self): + """The :func:`.select` construct this :class:`.Mapper` selects from + by default. + + Normally, this is equivalent to :attr:`.mapped_table`, unless + the ``with_polymorphic`` feature is in use, in which case the + full "polymorphic" selectable is returned. 
+
+ """
+ return self._with_polymorphic_selectable
+
+ def _with_polymorphic_args(self, spec=None, selectable=False,
+ innerjoin=False):
 if self.with_polymorphic:
 if not spec:
 spec = self.with_polymorphic[0]
 if selectable is False:
 selectable = self.with_polymorphic[1]
-
+ elif selectable is False:
+ selectable = None
 mappers = self._mappers_from_spec(spec, selectable)
 if selectable is not None:
 return mappers, selectable
 else:
- return mappers, self._selectable_from_mappers(mappers)
+ return mappers, self._selectable_from_mappers(mappers,
+ innerjoin)
+
+ @_memoized_configured_property
+ def _polymorphic_properties(self):
+ return list(self._iterate_polymorphic_properties(
+ self._with_polymorphic_mappers))

 def _iterate_polymorphic_properties(self, mappers=None):
- """Return an iterator of MapperProperty objects which will render into a SELECT."""
-
+ """Return an iterator of MapperProperty objects which will render into
+ a SELECT."""
 if mappers is None:
 mappers = self._with_polymorphic_mappers
@@ -926,29 +2116,188 @@ class Mapper(object):
 # from other mappers, as these are sometimes dependent on that
 # mapper's polymorphic selectable (which we don't want rendered)
 for c in util.unique_list(
- chain(*[list(mapper.iterate_properties) for mapper in [self] + mappers])
+ chain(*[
+ list(mapper.iterate_properties) for mapper in
+ [self] + mappers
+ ])
 ):
 if getattr(c, '_is_polymorphic_discriminator', False) and \
- (self.polymorphic_on is None or c.columns[0] is not self.polymorphic_on):
- continue
+ (self.polymorphic_on is None or
+ c.columns[0] is not self.polymorphic_on):
+ continue
 yield c
-
- @property
- def properties(self):
- raise NotImplementedError("Public collection of MapperProperty objects is "
- "provided by the get_property() and iterate_properties accessors.")

- @util.memoized_property
+ @_memoized_configured_property
+ def attrs(self):
+ """A namespace of all :class:`.MapperProperty` objects
+ associated with this mapper.
+
+ This is an object that provides each property based on
+ its key name. For instance, the mapper for a
+ ``User`` class which has ``User.name`` attribute would
+ provide ``mapper.attrs.name``, which would be the
+ :class:`.ColumnProperty` representing the ``name``
+ column. The namespace object can also be iterated,
+ which would yield each :class:`.MapperProperty`.
+
+ :class:`.Mapper` has several pre-filtered views
+ of this attribute which limit the types of properties
+ returned, including :attr:`.synonyms`, :attr:`.column_attrs`,
+ :attr:`.relationships`, and :attr:`.composites`.
+
+ .. warning::
+
+ The :attr:`.Mapper.attrs` accessor namespace is an
+ instance of :class:`.OrderedProperties`. This is
+ a dictionary-like object which includes a small number of
+ named methods such as :meth:`.OrderedProperties.items`
+ and :meth:`.OrderedProperties.values`. When
+ accessing attributes dynamically, favor using the dict-access
+ scheme, e.g. ``mapper.attrs[somename]`` over
+ ``getattr(mapper.attrs, somename)`` to avoid name collisions.
+
+ .. seealso::
+
+ :attr:`.Mapper.all_orm_descriptors`
+
+ """
+ if Mapper._new_mappers:
+ configure_mappers()
+ return util.ImmutableProperties(self._props)
+
+ @_memoized_configured_property
+ def all_orm_descriptors(self):
+ """A namespace of all :class:`.InspectionAttr` attributes associated
+ with the mapped class.
+
+ These attributes are in all cases Python :term:`descriptors`
+ associated with the mapped class or its superclasses.
+ + This namespace includes attributes that are mapped to the class + as well as attributes declared by extension modules. + It includes any Python descriptor type that inherits from + :class:`.InspectionAttr`. This includes + :class:`.QueryableAttribute`, as well as extension types such as + :class:`.hybrid_property`, :class:`.hybrid_method` and + :class:`.AssociationProxy`. + + To distinguish between mapped attributes and extension attributes, + the attribute :attr:`.InspectionAttr.extension_type` will refer + to a constant that distinguishes between different extension types. + + When dealing with a :class:`.QueryableAttribute`, the + :attr:`.QueryableAttribute.property` attribute refers to the + :class:`.MapperProperty` property, which is what you get when + referring to the collection of mapped properties via + :attr:`.Mapper.attrs`. + + .. warning:: + + The :attr:`.Mapper.all_orm_descriptors` accessor namespace is an + instance of :class:`.OrderedProperties`. This is + a dictionary-like object which includes a small number of + named methods such as :meth:`.OrderedProperties.items` + and :meth:`.OrderedProperties.values`. When + accessing attributes dynamically, favor using the dict-access + scheme, e.g. ``mapper.all_orm_descriptors[somename]`` over + ``getattr(mapper.all_orm_descriptors, somename)`` to avoid name + collisions. + + .. versionadded:: 0.8.0 + + .. seealso:: + + :attr:`.Mapper.attrs` + + """ + return util.ImmutableProperties( + dict(self.class_manager._all_sqla_attributes())) + + @_memoized_configured_property + def synonyms(self): + """Return a namespace of all :class:`.SynonymProperty` + properties maintained by this :class:`.Mapper`. + + .. seealso:: + + :attr:`.Mapper.attrs` - namespace of all :class:`.MapperProperty` + objects. + + """ + return self._filter_properties(properties.SynonymProperty) + + @_memoized_configured_property + def column_attrs(self): + """Return a namespace of all :class:`.ColumnProperty` + properties maintained by this :class:`.Mapper`. + + .. seealso:: + + :attr:`.Mapper.attrs` - namespace of all :class:`.MapperProperty` + objects. + + """ + return self._filter_properties(properties.ColumnProperty) + + @_memoized_configured_property + def relationships(self): + """A namespace of all :class:`.RelationshipProperty` properties + maintained by this :class:`.Mapper`. + + .. warning:: + + the :attr:`.Mapper.relationships` accessor namespace is an + instance of :class:`.OrderedProperties`. This is + a dictionary-like object which includes a small number of + named methods such as :meth:`.OrderedProperties.items` + and :meth:`.OrderedProperties.values`. When + accessing attributes dynamically, favor using the dict-access + scheme, e.g. ``mapper.relationships[somename]`` over + ``getattr(mapper.relationships, somename)`` to avoid name + collisions. + + .. seealso:: + + :attr:`.Mapper.attrs` - namespace of all :class:`.MapperProperty` + objects. + + """ + return self._filter_properties(properties.RelationshipProperty) + + @_memoized_configured_property + def composites(self): + """Return a namespace of all :class:`.CompositeProperty` + properties maintained by this :class:`.Mapper`. + + .. seealso:: + + :attr:`.Mapper.attrs` - namespace of all :class:`.MapperProperty` + objects. 
+ + """ + return self._filter_properties(properties.CompositeProperty) + + def _filter_properties(self, type_): + if Mapper._new_mappers: + configure_mappers() + return util.ImmutableProperties(util.OrderedDict( + (k, v) for k, v in self._props.items() + if isinstance(v, type_) + )) + + @_memoized_configured_property def _get_clause(self): """create a "get clause" based on the primary key. this is used by query.get() and many-to-one lazyloads to load this item by primary key. """ - params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key] - return sql.and_(*[k==v for (k, v) in params]), util.column_dict(params) + params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) + for primary_key in self.primary_key] + return sql.and_(*[k == v for (k, v) in params]), \ + util.column_dict(params) - @util.memoized_property + @_memoized_configured_property def _equivalent_columns(self): """Create a map of all *equivalent* columns, based on the determination of column pairs that are equated to @@ -970,6 +2319,7 @@ class Mapper(object): """ result = util.column_dict() + def visit_binary(binary): if binary.operator == operators.eq: if binary.left in result: @@ -980,48 +2330,63 @@ class Mapper(object): result[binary.right].add(binary.left) else: result[binary.right] = util.column_set((binary.left,)) - for mapper in self.base_mapper.polymorphic_iterator(): + for mapper in self.base_mapper.self_and_descendants: if mapper.inherit_condition is not None: - visitors.traverse(mapper.inherit_condition, {}, {'binary':visit_binary}) + visitors.traverse( + mapper.inherit_condition, {}, + {'binary': visit_binary}) return result def _is_userland_descriptor(self, obj): - return not isinstance(obj, (MapperProperty, attributes.InstrumentedAttribute)) and hasattr(obj, '__get__') + if isinstance(obj, (_MappedAttribute, + instrumentation.ClassManager, + expression.ColumnElement)): + return False + else: + return True - def _should_exclude(self, name, assigned_name, local): - """determine whether a particular property should be implicitly present on the class. + def _should_exclude(self, name, assigned_name, local, column): + """determine whether a particular property should be implicitly + present on the class. - This occurs when properties are propagated from an inherited class, or are - applied from the columns present in the mapped table. + This occurs when properties are propagated from an inherited class, or + are applied from the columns present in the mapped table. 
""" - # check for descriptors, either local or from - # an inherited class + # check for class-bound attributes and/or descriptors, + # either local or from an inherited class if local: - if self.class_.__dict__.get(assigned_name, None) is not None\ - and self._is_userland_descriptor(self.class_.__dict__[assigned_name]): + if self.class_.__dict__.get(assigned_name, None) is not None \ + and self._is_userland_descriptor( + self.class_.__dict__[assigned_name]): return True else: - if getattr(self.class_, assigned_name, None) is not None\ - and self._is_userland_descriptor(getattr(self.class_, assigned_name)): + if getattr(self.class_, assigned_name, None) is not None \ + and self._is_userland_descriptor( + getattr(self.class_, assigned_name)): return True - if (self.include_properties is not None and - name not in self.include_properties): + if self.include_properties is not None and \ + name not in self.include_properties and \ + (column is None or column not in self.include_properties): self._log("not including property %s" % (name)) return True - if (self.exclude_properties is not None and - name in self.exclude_properties): + if self.exclude_properties is not None and \ + ( + name in self.exclude_properties or + (column is not None and column in self.exclude_properties) + ): self._log("excluding property %s" % (name)) return True return False def common_parent(self, other): - """Return true if the given mapper shares a common inherited parent as this mapper.""" + """Return true if the given mapper shares a + common inherited parent as this mapper.""" return self.base_mapper is other.base_mapper @@ -1046,6 +2411,22 @@ class Mapper(object): yield m m = m.inherits + @_memoized_configured_property + def self_and_descendants(self): + """The collection including this mapper and all descendant mappers. + + This includes not just the immediately inheriting mappers but + all their inheriting mappers as well. + + """ + descendants = [] + stack = deque([self]) + while stack: + item = stack.popleft() + descendants.append(item) + stack.extend(item._inheriting_mappers) + return util.WeakSequence(descendants) + def polymorphic_iterator(self): """Iterate through the collection including this mapper and all descendant mappers. @@ -1055,109 +2436,173 @@ class Mapper(object): To iterate through an entire hierarchy, use ``mapper.base_mapper.polymorphic_iterator()``. - + """ - stack = deque([self]) - while stack: - item = stack.popleft() - yield item - stack.extend(item._inheriting_mappers) + return iter(self.self_and_descendants) def primary_mapper(self): - """Return the primary mapper corresponding to this mapper's class key (class).""" - + """Return the primary mapper corresponding to this mapper's class key + (class).""" + return self.class_manager.mapper + @property + def primary_base_mapper(self): + return self.class_manager.mapper.base_mapper + + def _result_has_identity_key(self, result, adapter=None): + pk_cols = self.primary_key + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + for col in pk_cols: + if not result._has_key(col): + return False + else: + return True + def identity_key_from_row(self, row, adapter=None): """Return an identity-map key for use in storing/retrieving an item from the identity map. - row - A ``sqlalchemy.engine.base.RowProxy`` instance or a - dictionary corresponding result-set ``ColumnElement`` - instances to their values within a row. + :param row: A :class:`.RowProxy` instance. 
The columns which are
+ mapped by this :class:`.Mapper` should be locatable in the row,
+ preferably via the :class:`.Column` object directly (as is the case
+ when a :func:`.select` construct is executed), or via string names of
+ the form ``<tablename>_<colname>``.

 """
 pk_cols = self.primary_key
 if adapter:
 pk_cols = [adapter.columns[c] for c in pk_cols]
- return (self._identity_class, tuple(row[column] for column in pk_cols))
+ return self._identity_class, \
+ tuple(row[column] for column in pk_cols)

 def identity_key_from_primary_key(self, primary_key):
 """Return an identity-map key for use in storing/retrieving an
 item from an identity map.

- primary_key
- A list of values indicating the identifier.
+ :param primary_key: A list of values indicating the identifier.

 """
- return (self._identity_class, tuple(util.to_list(primary_key)))
+ return self._identity_class, tuple(primary_key)

 def identity_key_from_instance(self, instance):
 """Return the identity key for the given instance, based on
 its primary key attributes.

+ If the instance's state is expired, calling this method
+ will result in a database check to see if the object has been deleted.
+ If the row no longer exists,
+ :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
 This value is typically also found on the instance state under the
 attribute name `key`.

 """
- return self.identity_key_from_primary_key(self.primary_key_from_instance(instance))
+ return self.identity_key_from_primary_key(
+ self.primary_key_from_instance(instance))

 def _identity_key_from_state(self, state):
- return self.identity_key_from_primary_key(self._primary_key_from_state(state))
+ dict_ = state.dict
+ manager = state.manager
+ return self._identity_class, tuple([
+ manager[self._columntoproperty[col].key].
+ impl.get(state, dict_, attributes.PASSIVE_RETURN_NEVER_SET)
+ for col in self.primary_key
+ ])

 def primary_key_from_instance(self, instance):
 """Return the list of primary key values for the given
 instance.

+ If the instance's state is expired, calling this method
+ will result in a database check to see if the object has been deleted.
+ If the row no longer exists,
+ :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
 """
 state = attributes.instance_state(instance)
- return self._primary_key_from_state(state)
+ return self._primary_key_from_state(state, attributes.PASSIVE_OFF)

- def _primary_key_from_state(self, state):
- return [self._get_state_attr_by_column(state, column) for column in self.primary_key]
+ def _primary_key_from_state(
+ self, state, passive=attributes.PASSIVE_RETURN_NEVER_SET):
+ dict_ = state.dict
+ manager = state.manager
+ return [
+ manager[prop.key].
+ impl.get(state, dict_, passive)
+ for prop in self._identity_key_props
+ ]

- def _get_col_to_prop(self, column):
- try:
- return self._columntoproperty[column]
- except KeyError:
- prop = self._props.get(column.key, None)
- if prop:
- raise orm_exc.UnmappedColumnError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop)))
- else:
- raise orm_exc.UnmappedColumnError("No column %s is configured on mapper %s..." % (column, self))
+ @_memoized_configured_property
+ def _identity_key_props(self):
+ return [self._columntoproperty[col] for col in self.primary_key]

- # TODO: improve names? 
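# [editor's note -- illustrative sketch, not part of the patch] The
# identity-key helpers above all reduce to the same (class, pk_tuple)
# two-tuple that the Session uses to index its identity map (a third
# "identity token" member arrives only in later releases).  Assuming a
# hypothetical mapped class ``User`` with a single integer primary key:

from sqlalchemy import Column, Integer, inspect
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = 'user'
    id = Column(Integer, primary_key=True)

mapper = inspect(User)  # the Mapper for User
# identity_key_from_primary_key() builds the key purely from values;
# identity_key_from_instance() derives the same key from a live object.
assert mapper.identity_key_from_primary_key([5]) == (User, (5,))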
- def _get_state_attr_by_column(self, state, column): - return self._get_col_to_prop(column).getattr(state, column) + @_memoized_configured_property + def _all_pk_props(self): + collection = set() + for table in self.tables: + collection.update(self._pks_by_table[table]) + return collection - def _set_state_attr_by_column(self, state, column, value): - return self._get_col_to_prop(column).setattr(state, value, column) + @_memoized_configured_property + def _should_undefer_in_wildcard(self): + cols = set(self.primary_key) + if self.polymorphic_on is not None: + cols.add(self.polymorphic_on) + return cols + + @_memoized_configured_property + def _primary_key_propkeys(self): + return set([prop.key for prop in self._all_pk_props]) + + def _get_state_attr_by_column( + self, state, dict_, column, + passive=attributes.PASSIVE_RETURN_NEVER_SET): + prop = self._columntoproperty[column] + return state.manager[prop.key].impl.get(state, dict_, passive=passive) + + def _set_committed_state_attr_by_column(self, state, dict_, column, value): + prop = self._columntoproperty[column] + state.manager[prop.key].impl.set_committed_value(state, dict_, value) + + def _set_state_attr_by_column(self, state, dict_, column, value): + prop = self._columntoproperty[column] + state.manager[prop.key].impl.set(state, dict_, value, None) def _get_committed_attr_by_column(self, obj, column): state = attributes.instance_state(obj) - return self._get_committed_state_attr_by_column(state, column) + dict_ = attributes.instance_dict(obj) + return self._get_committed_state_attr_by_column( + state, dict_, column, passive=attributes.PASSIVE_OFF) - def _get_committed_state_attr_by_column(self, state, column, passive=False): - return self._get_col_to_prop(column).getcommitted(state, column, passive=passive) + def _get_committed_state_attr_by_column( + self, state, dict_, column, + passive=attributes.PASSIVE_RETURN_NEVER_SET): + + prop = self._columntoproperty[column] + return state.manager[prop.key].impl.\ + get_committed_value(state, dict_, passive=passive) def _optimized_get_statement(self, state, attribute_names): - """assemble a WHERE clause which retrieves a given state by primary key, using a minimized set of tables. - - Applies to a joined-table inheritance mapper where the + """assemble a WHERE clause which retrieves a given state by primary + key, using a minimized set of tables. + + Applies to a joined-table inheritance mapper where the requested attribute names are only present on joined tables, - not the base table. The WHERE clause attempts to include + not the base table. The WHERE clause attempts to include only those tables to minimize joins. 
- + """ props = self._props - - tables = set(chain(* - (sqlutil.find_tables(props[key].columns[0], check_columns=True) - for key in attribute_names) - )) - + + tables = set(chain( + *[sql_util.find_tables(c, check_columns=True) + for key in attribute_names + for c in props[key].columns] + )) + if self.base_mapper.local_table in tables: return None @@ -1171,15 +2616,23 @@ class Mapper(object): return if leftcol.table not in tables: - leftval = self._get_committed_state_attr_by_column(state, leftcol, passive=True) - if leftval is attributes.PASSIVE_NO_RESULT: + leftval = self._get_committed_state_attr_by_column( + state, state.dict, + leftcol, + passive=attributes.PASSIVE_NO_INITIALIZE) + if leftval in orm_util._none_set: raise ColumnsNotAvailable() - binary.left = sql.bindparam(None, leftval, type_=binary.right.type) + binary.left = sql.bindparam(None, leftval, + type_=binary.right.type) elif rightcol.table not in tables: - rightval = self._get_committed_state_attr_by_column(state, rightcol, passive=True) - if rightval is attributes.PASSIVE_NO_RESULT: + rightval = self._get_committed_state_attr_by_column( + state, state.dict, + rightcol, + passive=attributes.PASSIVE_NO_INITIALIZE) + if rightval in orm_util._none_set: raise ColumnsNotAvailable() - binary.right = sql.bindparam(None, rightval, type_=binary.right.type) + binary.right = sql.bindparam(None, rightval, + type_=binary.right.type) allconds = [] @@ -1188,8 +2641,16 @@ class Mapper(object): for mapper in reversed(list(self.iterate_to_root())): if mapper.local_table in tables: start = True + elif not isinstance(mapper.local_table, + expression.TableClause): + return None if start and not mapper.single: - allconds.append(visitors.cloned_traverse(mapper.inherit_condition, {}, {'binary':visit_binary})) + allconds.append(visitors.cloned_traverse( + mapper.inherit_condition, + {}, + {'binary': visit_binary} + ) + ) except ColumnsNotAvailable: return None @@ -1204,654 +2665,220 @@ class Mapper(object): """Iterate each element and its mapper in an object graph, for all relationships that meet the given cascade rule. - ``type\_``: - The name of the cascade rule (i.e. save-update, delete, - etc.) + :param type_: + The name of the cascade rule (i.e. ``"save-update"``, ``"delete"``, + etc.). - ``state``: + .. note:: the ``"all"`` cascade is not accepted here. For a generic + object traversal function, see :ref:`faq_walk_objects`. + + :param state: The lead InstanceState. child items will be processed per the relationships defined for this object's mapper. - the return value are object instances; this provides a strong - reference so that they don't fall out of scope immediately. + :return: the method yields individual object instances. + + .. seealso:: + + :ref:`unitofwork_cascades` + + :ref:`faq_walk_objects` - illustrates a generic function to + traverse all objects without relying on cascades. 
""" - visited_instances = util.IdentitySet() - visitables = [(self._props.itervalues(), 'property', state)] + visited_states = set() + prp, mpp = object(), object() + + visitables = deque([(deque(self._props.values()), prp, + state, state.dict)]) while visitables: - iterator, item_type, parent_state = visitables[-1] - try: - if item_type == 'property': - prop = iterator.next() - visitables.append((prop.cascade_iterator(type_, parent_state, visited_instances, halt_on), 'mapper', None)) - elif item_type == 'mapper': - instance, instance_mapper, corresponding_state = iterator.next() - yield (instance, instance_mapper) - visitables.append((instance_mapper._props.itervalues(), 'property', corresponding_state)) - except StopIteration: + iterator, item_type, parent_state, parent_dict = visitables[-1] + if not iterator: visitables.pop() + continue - @util.memoized_property + if item_type is prp: + prop = iterator.popleft() + if type_ not in prop.cascade: + continue + queue = deque(prop.cascade_iterator( + type_, parent_state, parent_dict, + visited_states, halt_on)) + if queue: + visitables.append((queue, mpp, None, None)) + elif item_type is mpp: + instance, instance_mapper, corresponding_state, \ + corresponding_dict = iterator.popleft() + yield instance, instance_mapper, \ + corresponding_state, corresponding_dict + visitables.append((deque(instance_mapper._props.values()), + prp, corresponding_state, + corresponding_dict)) + + @_memoized_configured_property + def _compiled_cache(self): + return util.LRUCache(self._compiled_cache_size) + + @_memoized_configured_property def _sorted_tables(self): table_to_mapper = {} - for mapper in self.base_mapper.polymorphic_iterator(): + + for mapper in self.base_mapper.self_and_descendants: for t in mapper.tables: - table_to_mapper[t] = mapper - - sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys()) + table_to_mapper.setdefault(t, mapper) + + extra_dependencies = [] + for table, mapper in table_to_mapper.items(): + super_ = mapper.inherits + if super_: + extra_dependencies.extend([ + (super_table, table) + for super_table in super_.tables + ]) + + def skip(fk): + # attempt to skip dependencies that are not + # significant to the inheritance chain + # for two tables that are related by inheritance. + # while that dependency may be important, it's technically + # not what we mean to sort on here. + parent = table_to_mapper.get(fk.parent.table) + dep = table_to_mapper.get(fk.column.table) + if parent is not None and \ + dep is not None and \ + dep is not parent and \ + dep.inherit_condition is not None: + cols = set(sql_util._find_columns(dep.inherit_condition)) + if parent.inherit_condition is not None: + cols = cols.union(sql_util._find_columns( + parent.inherit_condition)) + return fk.parent not in cols and fk.column not in cols + else: + return fk.parent not in cols + return False + + sorted_ = sql_util.sort_tables(table_to_mapper, + skip_fn=skip, + extra_dependencies=extra_dependencies) + ret = util.OrderedDict() for t in sorted_: ret[t] = table_to_mapper[t] return ret - def _save_obj(self, states, uowtransaction, postupdate=False, - post_update_cols=None, single=False): - """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. + def _memo(self, key, callable_): + if key in self._memoized_values: + return self._memoized_values[key] + else: + self._memoized_values[key] = value = callable_() + return value - This is called within the context of a UOWTransaction during a - flush operation. 
+ @util.memoized_property
+ def _table_to_equated(self):
+ """memoized map of tables to collections of columns to be
+ synchronized upwards to the base mapper."""
-
- `_save_obj` issues SQL statements not just for instances mapped
- directly by this mapper, but for instances mapped by all
- inheriting mappers as well. This is to maintain proper insert
- ordering among a polymorphic chain of instances. Therefore
- _save_obj is typically called only on a *base mapper*, or a
- mapper which does not inherit from any other mapper.
-
- """
- # if batch=false, call _save_obj separately for each object
- if not single and not self.batch:
- for state in _sort_states(states):
- self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
+ result = util.defaultdict(list)
+
+ for table in self._sorted_tables:
+ cols = set(table.c)
+ for m in self.iterate_to_root():
+ if m._inherits_equated_pairs and \
+ cols.intersection(
+ util.reduce(set.union,
+ [l.proxy_set for l, r in
+ m._inherits_equated_pairs])
+ ):
+ result[table].append((m, m._inherits_equated_pairs))
+
+ return result
+
+
+def configure_mappers():
+ """Initialize the inter-mapper relationships of all mappers that
+ have been constructed thus far.
+
+ This function can be called any number of times, but in
+ most cases is invoked automatically, the first time mappings are used,
+ as well as whenever mappings are used and additional not-yet-configured
+ mappers have been constructed.
+
+ Points at which this occurs include when a mapped class is instantiated
+ into an instance, as well as when the :meth:`.Session.query` method
+ is used.
+
+ The :func:`.configure_mappers` function provides several event hooks
+ that can be used to augment its functionality. These methods include:
+
+ * :meth:`.MapperEvents.before_configured` - called once before
+ :func:`.configure_mappers` does any work; this can be used to establish
+ additional options, properties, or related mappings before the operation
+ proceeds.
+
+ * :meth:`.MapperEvents.mapper_configured` - called as each individual
+ :class:`.Mapper` is configured within the process; will include all
+ mapper state except for backrefs set up by other mappers that are still
+ to be configured.
+
+ * :meth:`.MapperEvents.after_configured` - called once after
+ :func:`.configure_mappers` is complete; at this stage, all
+ :class:`.Mapper` objects that are known to SQLAlchemy will be fully
+ configured. Note that the calling application may still have other
+ mappings that haven't been produced yet, such as if they are in modules
+ as yet unimported.
+ + """ + + if not Mapper._new_mappers: + return + + _CONFIGURE_MUTEX.acquire() + try: + global _already_compiling + if _already_compiling: return - - # if session has a connection callable, - # organize individual states with the connection to use for insert/update - tups = [] - if 'connection_callable' in uowtransaction.mapper_flush_opts: - connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - for state in _sort_states(states): - m = _state_mapper(state) - tups.append( - ( - state, - m, - connection_callable(self, state.obj()), - _state_has_identity(state), - state.key or m._identity_key_from_state(state) - ) - ) - else: - connection = uowtransaction.transaction.connection(self) - for state in _sort_states(states): - m = _state_mapper(state) - tups.append( - ( - state, - m, - connection, - _state_has_identity(state), - state.key or m._identity_key_from_state(state) - ) - ) - - if not postupdate: - # call before_XXX extensions - for state, mapper, connection, has_identity, instance_key in tups: - if not has_identity: - if 'before_insert' in mapper.extension: - mapper.extension.before_insert(mapper, connection, state.obj()) - else: - if 'before_update' in mapper.extension: - mapper.extension.before_update(mapper, connection, state.obj()) - - row_switches = {} - if not postupdate: - for state, mapper, connection, has_identity, instance_key in tups: - # detect if we have a "pending" instance (i.e. has no instance_key attached to it), - # and another instance with the same identity key already exists as persistent. convert to an - # UPDATE if so. - if not has_identity and instance_key in uowtransaction.session.identity_map: - instance = uowtransaction.session.identity_map[instance_key] - existing = attributes.instance_state(instance) - if not uowtransaction.is_deleted(existing): - raise orm_exc.FlushError( - "New instance %s with identity key %s conflicts " - "with persistent instance %s" % - (state_str(state), instance_key, state_str(existing))) - - self._log_debug( - "detected row switch for identity %s. 
will update %s, remove %s from " - "transaction", instance_key, state_str(state), state_str(existing)) - - # remove the "delete" flag from the existing element - uowtransaction.set_row_switch(existing) - row_switches[state] = existing - - table_to_mapper = self._sorted_tables - - for table in table_to_mapper.iterkeys(): - insert = [] - update = [] - - for state, mapper, connection, has_identity, instance_key in tups: - if table not in mapper._pks_by_table: - continue - - pks = mapper._pks_by_table[table] - - isinsert = not has_identity and not postupdate and state not in row_switches - - params = {} - value_params = {} - hasdata = False - - if isinsert: - for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col: - params[col.key] = mapper.version_id_generator(None) - elif mapper.polymorphic_on is not None and \ - mapper.polymorphic_on.shares_lineage(col): - value = mapper.polymorphic_identity - if ((col.default is None and - col.server_default is None) or - value is not None): - params[col.key] = value - elif col in pks: - value = mapper._get_state_attr_by_column(state, col) - if value is not None: - params[col.key] = value - else: - value = mapper._get_state_attr_by_column(state, col) - if ((col.default is None and - col.server_default is None) or - value is not None): - if isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value - insert.append((state, params, mapper, connection, value_params)) - else: - for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col: - params[col._label] = mapper._get_state_attr_by_column(row_switches.get(state, state), col) - params[col.key] = mapper.version_id_generator(params[col._label]) - for prop in mapper._columntoproperty.itervalues(): - history = attributes.get_state_history(state, prop.key, passive=True) - if history.added: - hasdata = True - elif mapper.polymorphic_on is not None and \ - mapper.polymorphic_on.shares_lineage(col) and col not in pks: - pass - else: - if post_update_cols is not None and col not in post_update_cols: - if col in pks: - params[col._label] = mapper._get_state_attr_by_column(state, col) - continue - - prop = mapper._columntoproperty[col] - history = attributes.get_state_history(state, prop.key, passive=True) - if history.added: - if isinstance(history.added[0], sql.ClauseElement): - value_params[col] = history.added[0] - else: - params[col.key] = prop.get_col_value(col, history.added[0]) - - if col in pks: - if history.deleted: - # if passive_updates and sync detected this was a - # pk->pk sync, use the new value to locate the row, - # since the DB would already have set this - if ("pk_cascaded", state, col) in \ - uowtransaction.attributes: - params[col._label] = \ - prop.get_col_value(col, history.added[0]) - else: - # use the old value to locate the row - params[col._label] = \ - prop.get_col_value(col, history.deleted[0]) - hasdata = True - else: - # row switch logic can reach us here - # remove the pk from the update params so the update doesn't - # attempt to include the pk in the update statement - del params[col.key] - params[col._label] = \ - prop.get_col_value(col, history.added[0]) - else: - hasdata = True - elif col in pks: - params[col._label] = mapper._get_state_attr_by_column(state, col) - if hasdata: - update.append((state, params, mapper, connection, value_params)) - - if update: - mapper = table_to_mapper[table] - clause = sql.and_() - - for col in mapper._pks_by_table[table]: - clause.clauses.append(col == 
sql.bindparam(col._label, type_=col.type)) - - if mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col): - - clause.clauses.append(mapper.version_id_col ==\ - sql.bindparam(mapper.version_id_col._label, type_=col.type)) - - statement = table.update(clause) - - rows = 0 - for state, params, mapper, connection, value_params in update: - c = connection.execute(statement.values(value_params), params) - mapper._postfetch(uowtransaction, connection, table, - state, c, c.last_updated_params(), value_params) - - rows += c.rowcount - - if connection.dialect.supports_sane_rowcount: - if rows != len(update): - raise orm_exc.ConcurrentModificationError( - "Updated rowcount %d does not match number of objects updated %d" % - (rows, len(update))) - - elif mapper.version_id_col is not None: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % c.dialect.dialect_description, - stacklevel=12) - - if insert: - statement = table.insert() - for state, params, mapper, connection, value_params in insert: - c = connection.execute(statement.values(value_params), params) - primary_key = c.inserted_primary_key - - if primary_key is not None: - # set primary key attributes - for i, col in enumerate(mapper._pks_by_table[table]): - if mapper._get_state_attr_by_column(state, col) is None and \ - len(primary_key) > i: - mapper._set_state_attr_by_column(state, col, primary_key[i]) - - mapper._postfetch(uowtransaction, connection, table, - state, c, c.last_inserted_params(), value_params) - - - if not postupdate: - for state, mapper, connection, has_identity, instance_key in tups: - - # expire readonly attributes - readonly = state.unmodified.intersection( - p.key for p in mapper._readonly_props - ) - - if readonly: - _expire_state(state, state.dict, readonly) - - # if specified, eagerly refresh whatever has - # been expired. - if self.eager_defaults and state.unloaded: - state.key = self._identity_key_from_state(state) - uowtransaction.session.query(self)._get( - state.key, refresh_state=state, - only_load_props=state.unloaded) - - # call after_XXX extensions - if not has_identity: - if 'after_insert' in mapper.extension: - mapper.extension.after_insert(mapper, connection, state.obj()) - else: - if 'after_update' in mapper.extension: - mapper.extension.after_update(mapper, connection, state.obj()) - - def _postfetch(self, uowtransaction, connection, table, - state, resultproxy, params, value_params): - """Expire attributes in need of newly persisted database state.""" - - postfetch_cols = resultproxy.postfetch_cols() - generated_cols = list(resultproxy.prefetch_cols()) - - if self.polymorphic_on is not None: - po = table.corresponding_column(self.polymorphic_on) - if po is not None: - generated_cols.append(po) - - if self.version_id_col is not None: - generated_cols.append(self.version_id_col) - - for c in generated_cols: - if c.key in params and c in self._columntoproperty: - self._set_state_attr_by_column(state, c, params[c.key]) - - deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]] - - if deferred_props: - _expire_state(state, state.dict, deferred_props) - - # synchronize newly inserted ids from one table to the next - # TODO: this still goes a little too often. 
would be nice to - # have definitive list of "columns that changed" here - cols = set(table.c) - for m in self.iterate_to_root(): - if m._inherits_equated_pairs and \ - cols.intersection([l for l, r in m._inherits_equated_pairs]): - sync.populate(state, m, state, m, - m._inherits_equated_pairs, - uowtransaction, - self.passive_updates) - - def _delete_obj(self, states, uowtransaction): - """Issue ``DELETE`` statements for a list of objects. - - This is called within the context of a UOWTransaction during a - flush operation. - - """ - if 'connection_callable' in uowtransaction.mapper_flush_opts: - connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)] - else: - connection = uowtransaction.transaction.connection(self) - tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)] - - for state, mapper, connection in tups: - if 'before_delete' in mapper.extension: - mapper.extension.before_delete(mapper, connection, state.obj()) - - table_to_mapper = self._sorted_tables - - for table in reversed(table_to_mapper.keys()): - delete = {} - for state, mapper, connection in tups: - if table not in mapper._pks_by_table: - continue - - params = {} - if not _state_has_identity(state): - continue - else: - delete.setdefault(connection, []).append(params) - for col in mapper._pks_by_table[table]: - params[col.key] = mapper._get_state_attr_by_column(state, col) - if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): - params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col) - - for connection, del_objects in delete.iteritems(): - mapper = table_to_mapper[table] - clause = sql.and_() - for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col.key, type_=col.type)) - if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): - clause.clauses.append( - mapper.version_id_col == - sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type)) - statement = table.delete(clause) - c = connection.execute(statement, del_objects) - if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects): - raise orm_exc.ConcurrentModificationError("Deleted rowcount %d does not match " - "number of objects deleted %d" % (c.rowcount, len(del_objects))) - - for state, mapper, connection in tups: - if 'after_delete' in mapper.extension: - mapper.extension.after_delete(mapper, connection, state.obj()) - - def _register_dependencies(self, uowcommit): - """Register ``DependencyProcessor`` instances with a - ``unitofwork.UOWTransaction``. - - This call `register_dependencies` on all attached - ``MapperProperty`` instances. 
- - """ - for dep in self._props.values() + self._dependency_processors: - dep.register_dependencies(uowcommit) - - def _register_processors(self, uowcommit): - for dep in self._props.values() + self._dependency_processors: - dep.register_processors(uowcommit) - - def _instance_processor(self, context, path, adapter, - polymorphic_from=None, extension=None, - only_load_props=None, refresh_state=None, - polymorphic_discriminator=None): - - """Produce a mapper level row processor callable - which processes rows into mapped instances.""" - - pk_cols = self.primary_key - - if polymorphic_from or refresh_state: - polymorphic_on = None - else: - if polymorphic_discriminator is not None: - polymorphic_on = polymorphic_discriminator - else: - polymorphic_on = self.polymorphic_on - polymorphic_instances = util.PopulateDict( - self._configure_subclass_mapper(context, path, adapter) - ) - - version_id_col = self.version_id_col - - if adapter: - pk_cols = [adapter.columns[c] for c in pk_cols] - if polymorphic_on is not None: - polymorphic_on = adapter.columns[polymorphic_on] - if version_id_col is not None: - version_id_col = adapter.columns[version_id_col] - - identity_class = self._identity_class - def identity_key(row): - return (identity_class, tuple([row[column] for column in pk_cols])) - - new_populators = [] - existing_populators = [] - load_path = context.query._current_path + path - - def populate_state(state, dict_, row, isnew, only_load_props): - if isnew: - if context.propagate_options: - state.load_options = context.propagate_options - if state.load_options: - state.load_path = load_path - - if not new_populators: - new_populators[:], existing_populators[:] = \ - self._populators(context, path, row, adapter) - - if isnew: - populators = new_populators - else: - populators = existing_populators - - if only_load_props: - populators = [p for p in populators if p[0] in only_load_props] - - for key, populator in populators: - populator(state, dict_, row) - - session_identity_map = context.session.identity_map - - if not extension: - extension = self.extension - - translate_row = extension.get('translate_row', None) - create_instance = extension.get('create_instance', None) - populate_instance = extension.get('populate_instance', None) - append_result = extension.get('append_result', None) - populate_existing = context.populate_existing or self.always_refresh - if self.allow_partial_pks: - is_not_primary_key = _none_set.issuperset - else: - is_not_primary_key = _none_set.issubset - - def _instance(row, result): - if translate_row: - ret = translate_row(self, context, row) - if ret is not EXT_CONTINUE: - row = ret - - if polymorphic_on is not None: - discriminator = row[polymorphic_on] - if discriminator is not None: - _instance = polymorphic_instances[discriminator] - if _instance: - return _instance(row, result) - - # determine identity key - if refresh_state: - identitykey = refresh_state.key - if identitykey is None: - # super-rare condition; a refresh is being called - # on a non-instance-key instance; this is meant to only - # occur within a flush() - identitykey = self._identity_key_from_state(refresh_state) - else: - identitykey = identity_key(row) - - instance = session_identity_map.get(identitykey) - if instance is not None: - state = attributes.instance_state(instance) - dict_ = attributes.instance_dict(instance) - - isnew = state.runid != context.runid - currentload = not isnew - loaded_instance = False - - if not currentload and \ - version_id_col is not None and \ - 
context.version_check and \ - self._get_state_attr_by_column( - state, - self.version_id_col) != row[version_id_col]: - - raise orm_exc.ConcurrentModificationError( - "Instance '%s' version of %s does not match %s" - % (state_str(state), - self._get_state_attr_by_column(state, self.version_id_col), - row[version_id_col])) - elif refresh_state: - # out of band refresh_state detected (i.e. its not in the session.identity_map) - # honor it anyway. this can happen if a _get() occurs within save_obj(), such as - # when eager_defaults is True. - state = refresh_state - instance = state.obj() - dict_ = attributes.instance_dict(instance) - isnew = state.runid != context.runid - currentload = True - loaded_instance = False - else: - # check for non-NULL values in the primary key columns, - # else no entity is returned for the row - if is_not_primary_key(identitykey[1]): - return None - - isnew = True - currentload = True - loaded_instance = True - - if create_instance: - instance = create_instance(self, context, row, self.class_) - if instance is EXT_CONTINUE: - instance = self.class_manager.new_instance() - else: - manager = attributes.manager_of_class(instance.__class__) - # TODO: if manager is None, raise a friendly error about - # returning instances of unmapped types - manager.setup_instance(instance) - else: - instance = self.class_manager.new_instance() - - dict_ = attributes.instance_dict(instance) - state = attributes.instance_state(instance) - state.key = identitykey - - # manually adding instance to session. for a complete add, - # session._finalize_loaded() must be called. - state.session_id = context.session.hash_key - session_identity_map.add(state) - - if currentload or populate_existing: - if isnew: - state.runid = context.runid - context.progress[state] = dict_ - - if not populate_instance or \ - populate_instance(self, context, row, instance, - only_load_props=only_load_props, - instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - populate_state(state, dict_, row, isnew, only_load_props) - - else: - # populate attributes on non-loading instances which have been expired - # TODO: apply eager loads to un-lazy loaded collections ? 
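# [editor's note -- illustrative sketch, not part of the patch] The
# version check being removed above lives on in the modern persistence
# code; it is enabled per-mapper via version_id_col.  Assuming a
# hypothetical declarative class:

from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class Widget(Base):
    __tablename__ = 'widget'
    id = Column(Integer, primary_key=True)
    version_id = Column(Integer, nullable=False)
    name = Column(String)
    # each UPDATE/DELETE adds "WHERE version_id = :expected" and bumps
    # the counter; a zero-row result raises StaleDataError (successor
    # of the ConcurrentModificationError seen in this old code).
    __mapper_args__ = {'version_id_col': version_id}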
- if state in context.partials or state.unloaded: - - if state in context.partials: - isnew = False - (d_, attrs) = context.partials[state] - else: - isnew = True - attrs = state.unloaded - # allow query.instances to commit the subset of attrs - context.partials[state] = (dict_, attrs) - - if not populate_instance or \ - populate_instance(self, context, row, instance, - only_load_props=attrs, - instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - populate_state(state, dict_, row, isnew, attrs) - - if loaded_instance: - state._run_on_load(instance) - - if result is not None and \ - (not append_result or - append_result(self, context, row, instance, - result, instancekey=identitykey, isnew=isnew) - is EXT_CONTINUE): - result.append(instance) - - return instance - return _instance - - def _populators(self, context, path, row, adapter): - """Produce a collection of attribute level row processor callables.""" - - new_populators, existing_populators = [], [] - for prop in self._props.itervalues(): - newpop, existingpop = prop.create_row_processor(context, path, self, row, adapter) - if newpop: - new_populators.append((prop.key, newpop)) - if existingpop: - existing_populators.append((prop.key, existingpop)) - return new_populators, existing_populators - - def _configure_subclass_mapper(self, context, path, adapter): - """Produce a mapper level row processor callable factory for mappers inheriting this one.""" - - def configure_subclass_mapper(discriminator): - try: - mapper = self.polymorphic_map[discriminator] - except KeyError: - raise AssertionError("No such polymorphic_identity %r is defined" % discriminator) - if mapper is self: - return None - - # replace the tip of the path info with the subclass mapper being used. - # that way accurate "load_path" info is available for options - # invoked during deferred loads. - # we lose AliasedClass path elements this way, but currently, - # those are not needed at this stage. - - # this asserts to true - #assert mapper.isa(_class_to_mapper(path[-1])) - - return mapper._instance_processor(context, path[0:-1] + (mapper,), - adapter, polymorphic_from=self) - return configure_subclass_mapper - -log.class_logger(Mapper) + _already_compiling = True + try: + + # double-check inside mutex + if not Mapper._new_mappers: + return + + Mapper.dispatch._for_class(Mapper).before_configured() + # initialize properties on all mappers + # note that _mapper_registry is unordered, which + # may randomly conceal/reveal issues related to + # the order of mapper compilation + + for mapper in list(_mapper_registry): + if getattr(mapper, '_configure_failed', False): + e = sa_exc.InvalidRequestError( + "One or more mappers failed to initialize - " + "can't proceed with initialization of other " + "mappers. Triggering mapper: '%s'. 
" + "Original exception was: %s" + % (mapper, mapper._configure_failed)) + e._configure_failed = mapper._configure_failed + raise e + if not mapper.configured: + try: + mapper._post_configure_properties() + mapper._expire_memoizations() + mapper.dispatch.mapper_configured( + mapper, mapper.class_) + except Exception: + exc = sys.exc_info()[1] + if not hasattr(exc, '_configure_failed'): + mapper._configure_failed = exc + raise + + Mapper._new_mappers = False + finally: + _already_compiling = False + finally: + _CONFIGURE_MUTEX.release() + Mapper.dispatch._for_class(Mapper).after_configured() def reconstructor(fn): @@ -1873,86 +2900,108 @@ def reconstructor(fn): fn.__sa_reconstructor__ = True return fn -def validates(*names): - """Decorate a method as a 'validator' for one or more named properties. + +def validates(*names, **kw): + r"""Decorate a method as a 'validator' for one or more named properties. Designates a method as a validator, a method which receives the name of the attribute as well as a value to be assigned, or in the - case of a collection to be added to the collection. The function - can then raise validation exceptions to halt the process from continuing, - or can modify or replace the value before proceeding. The function - should otherwise return the given value. + case of a collection, the value to be added to the collection. + The function can then raise validation exceptions to halt the + process from continuing (where Python's built-in ``ValueError`` + and ``AssertionError`` exceptions are reasonable choices), or can + modify or replace the value before proceeding. The function should + otherwise return the given value. + + Note that a validator for a collection **cannot** issue a load of that + collection within the validation routine - this usage raises + an assertion to avoid recursion overflows. This is a reentrant + condition which is not supported. + + :param \*names: list of attribute names to be validated. + :param include_removes: if True, "remove" events will be + sent as well - the validation function must accept an additional + argument "is_remove" which will be a boolean. + + .. versionadded:: 0.7.7 + :param include_backrefs: defaults to ``True``; if ``False``, the + validation function will not emit if the originator is an attribute + event related via a backref. This can be used for bi-directional + :func:`.validates` usage where only one validator should emit per + attribute operation. + + .. versionadded:: 0.9.0 + + .. 
seealso:: + + :ref:`simple_validators` - usage examples for :func:`.validates` """ + include_removes = kw.pop('include_removes', False) + include_backrefs = kw.pop('include_backrefs', True) + def wrap(fn): fn.__sa_validators__ = names + fn.__sa_validation_opts__ = { + "include_removes": include_removes, + "include_backrefs": include_backrefs + } return fn return wrap -def _event_on_init(state, instance, args, kwargs): - """Trigger mapper compilation and run init_instance hooks.""" +def _event_on_load(state, ctx): instrumenting_mapper = state.manager.info[_INSTRUMENTOR] - # compile() always compiles all mappers - instrumenting_mapper.compile() - if 'init_instance' in instrumenting_mapper.extension: - instrumenting_mapper.extension.init_instance( - instrumenting_mapper, instrumenting_mapper.class_, - state.manager.events.original_init, - instance, args, kwargs) + if instrumenting_mapper._reconstructor: + instrumenting_mapper._reconstructor(state.obj()) -def _event_on_init_failure(state, instance, args, kwargs): - """Run init_failed hooks.""" - instrumenting_mapper = state.manager.info[_INSTRUMENTOR] - if 'init_failed' in instrumenting_mapper.extension: - util.warn_exception( - instrumenting_mapper.extension.init_failed, - instrumenting_mapper, instrumenting_mapper.class_, - state.manager.events.original_init, instance, args, kwargs) +def _event_on_first_init(manager, cls): + """Initial mapper compilation trigger. -def _event_on_resurrect(state, instance): - # re-populate the primary key elements - # of the dict based on the mapping. - instrumenting_mapper = state.manager.info[_INSTRUMENTOR] - for col, val in zip(instrumenting_mapper.primary_key, state.key[1]): - instrumenting_mapper._set_state_attr_by_column(state, col, val) - - -def _sort_states(states): - return sorted(states, key=operator.attrgetter('sort_key')) + instrumentation calls this one when InstanceState + is first generated, and is needed for legacy mutable + attributes to work. + """ -def _load_scalar_attributes(state, attribute_names): - """initiate a column-based attribute refresh operation.""" - - mapper = _state_mapper(state) - session = _state_session(state) - if not session: - raise orm_exc.DetachedInstanceError("Instance %s is not bound to a Session; " - "attribute refresh operation cannot proceed" % (state_str(state))) + instrumenting_mapper = manager.info.get(_INSTRUMENTOR) + if instrumenting_mapper: + if Mapper._new_mappers: + configure_mappers() - has_key = _state_has_identity(state) - result = False - if mapper.inherits and not mapper.concrete: - statement = mapper._optimized_get_statement(state, attribute_names) - if statement is not None: - result = session.query(mapper).from_statement(statement).\ - _get(None, - only_load_props=attribute_names, - refresh_state=state) +def _event_on_init(state, args, kwargs): + """Run init_instance hooks. - if result is False: - if has_key: - identity_key = state.key - else: - identity_key = mapper._identity_key_from_state(state) - result = session.query(mapper)._get( - identity_key, - refresh_state=state, - only_load_props=attribute_names) + This also includes mapper compilation, normally not needed + here but helps with some piecemeal configuration + scenarios (such as in the ORM tutorial). - # if instance is pending, a refresh operation - # may not complete (even if PK attributes are assigned) - if has_key and result is None: - raise orm_exc.ObjectDeletedError("Instance '%s' has been deleted." 
% state_str(state))
+ """
+
+ instrumenting_mapper = state.manager.info.get(_INSTRUMENTOR)
+ if instrumenting_mapper:
+ if Mapper._new_mappers:
+ configure_mappers()
+ if instrumenting_mapper._set_polymorphic_identity:
+ instrumenting_mapper._set_polymorphic_identity(state)
+
+
+class _ColumnMapping(dict):
+ """Error reporting helper for mapper._columntoproperty."""
+
+ __slots__ = 'mapper',
+
+ def __init__(self, mapper):
+ self.mapper = mapper
+
+ def __missing__(self, column):
+ prop = self.mapper._props.get(column)
+ if prop:
+ raise orm_exc.UnmappedColumnError(
+ "Column '%s.%s' is not available, due to "
+ "conflicting property '%s':%r" % (
+ column.table.name, column.name, column.key, prop))
+ raise orm_exc.UnmappedColumnError(
+ "No column %s is configured on mapper %s..." %
+ (column, self.mapper))
diff --git a/sqlalchemy/orm/properties.py b/sqlalchemy/orm/properties.py
index 80d101b..63e7e1e 100644
--- a/sqlalchemy/orm/properties.py
+++ b/sqlalchemy/orm/properties.py
@@ -1,91 +1,191 @@
-# properties.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
+# orm/properties.py
+# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php

 """MapperProperty implementations.

-This is a private module which defines the behavior of invidual ORM-mapped
-attributes.
+This is a private module which defines the behavior of individual ORM-
+mapped attributes.

 """
+from __future__ import absolute_import

-from sqlalchemy import sql, util, log
-import sqlalchemy.exceptions as sa_exc
-from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, join_condition
-from sqlalchemy.sql import operators, expression
-from sqlalchemy.orm import (
- attributes, dependency, mapper, object_mapper, strategies,
- )
-from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, _orm_annotate, _orm_deannotate
-from sqlalchemy.orm.interfaces import (
- MANYTOMANY, MANYTOONE, MapperProperty, ONETOMANY, PropComparator,
- StrategizedProperty,
- )
-NoneType = type(None)
+from .. import util, log
+from ..sql import expression
+from . import attributes
+from .util import _orm_full_deannotate

-__all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty',
- 'ComparableProperty', 'RelationshipProperty', 'RelationProperty', 'BackRef')
+from .interfaces import PropComparator, StrategizedProperty

+__all__ = ['ColumnProperty', 'CompositeProperty', 'SynonymProperty',
+ 'ComparableProperty', 'RelationshipProperty']

+@log.class_logger
 class ColumnProperty(StrategizedProperty):
- """Describes an object attribute that corresponds to a table column."""
+ """Describes an object attribute that corresponds to a table column.
+
+ Public constructor is the :func:`.orm.column_property` function.
+
+ """
+
+ strategy_wildcard_key = 'column'
+
+ __slots__ = (
+ '_orig_columns', 'columns', 'group', 'deferred',
+ 'instrument', 'comparator_factory', 'descriptor', 'extension',
+ 'active_history', 'expire_on_flush', 'info', 'doc',
+ 'strategy_key', '_creation_order', '_is_polymorphic_discriminator',
+ '_mapped_by_synonym', '_deferred_column_loader')

 def __init__(self, *columns, **kwargs):
- """Construct a ColumnProperty.
+ r"""Provide a column-level property for use with a Mapper.

- :param \*columns: The list of `columns` describes a single
- object property. 
If there are multiple tables joined
- together for the mapper, this list represents the equivalent
- column as it appears across each table.
+ Column-based properties can normally be applied to the mapper's
+ ``properties`` dictionary using the :class:`.Column` element directly.
+ Use this function when the given column is not directly present within
+ the mapper's selectable; examples include SQL expressions, functions,
+ and scalar SELECT queries.
+
+ Columns that aren't present in the mapper's selectable won't be
+ persisted by the mapper and are effectively "read-only" attributes.
+
+ :param \*cols:
+ list of Column objects to be mapped.
+
+ :param active_history=False:
+ When ``True``, indicates that the "previous" value for a
+ scalar attribute should be loaded when replaced, if not
+ already loaded. Normally, history tracking logic for
+ simple non-primary-key scalar values only needs to be
+ aware of the "new" value in order to perform a flush. This
+ flag is available for applications that make use of
+ :func:`.attributes.get_history` or :meth:`.Session.is_modified`
+ which also need to know
+ the "previous" value of the attribute.
+
+ .. versionadded:: 0.6.6
+
+ :param comparator_factory: a class which extends
+ :class:`.ColumnProperty.Comparator` which provides custom SQL
+ clause generation for comparison operations.

 :param group:
+ a group name for this property when marked as deferred.

 :param deferred:
+ when True, the column property is "deferred", meaning that
+ it does not load immediately, and is instead loaded when the
+ attribute is first accessed on an instance. See also
+ :func:`~sqlalchemy.orm.deferred`.

- :param comparator_factory:
+ :param doc:
+ optional string that will be applied as the doc on the
+ class-bound descriptor.

- :param descriptor:
+ :param expire_on_flush=True:
+ Disable expiry on flush. A column_property() which refers
+ to a SQL expression (and not a single table-bound column)
+ is considered to be a "read only" property; populating it
+ has no effect on the state of data, and it can only return
+ database state. For this reason a column_property()'s value
+ is expired whenever the parent object is involved in a
+ flush, that is, has any kind of "dirty" state within a flush.
+ Setting this parameter to ``False`` will have the effect of
+ leaving any existing value present after the flush proceeds.
+ Note, however, that the :class:`.Session` with default expiration
+ settings still expires
+ all attributes after a :meth:`.Session.commit` call.
+
+ .. versionadded:: 0.7.3
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.MapperProperty.info` attribute of this object.
+
+ .. versionadded:: 0.8

 :param extension:
+ an
+ :class:`.AttributeExtension`
+ instance, or list of extensions, which will be prepended
+ to the list of attribute listeners for the resulting
+ descriptor placed on the class.
+ **Deprecated.** Please see :class:`.AttributeEvents`. 
""" - self.columns = [expression._labeled(c) for c in columns] + super(ColumnProperty, self).__init__() + self._orig_columns = [expression._labeled(c) for c in columns] + self.columns = [expression._labeled(_orm_full_deannotate(c)) + for c in columns] self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) self.instrument = kwargs.pop('_instrument', True) - self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator) + self.comparator_factory = kwargs.pop('comparator_factory', + self.__class__.Comparator) self.descriptor = kwargs.pop('descriptor', None) self.extension = kwargs.pop('extension', None) + self.active_history = kwargs.pop('active_history', False) + self.expire_on_flush = kwargs.pop('expire_on_flush', True) + + if 'info' in kwargs: + self.info = kwargs.pop('info') + + if 'doc' in kwargs: + self.doc = kwargs.pop('doc') + else: + for col in reversed(self.columns): + doc = getattr(col, 'doc', None) + if doc is not None: + self.doc = doc + break + else: + self.doc = None + if kwargs: raise TypeError( "%s received unexpected keyword argument(s): %s" % ( - self.__class__.__name__, ', '.join(sorted(kwargs.keys())))) + self.__class__.__name__, + ', '.join(sorted(kwargs.keys())))) util.set_creation_order(self) - if not self.instrument: - self.strategy_class = strategies.UninstrumentedColumnLoader - elif self.deferred: - self.strategy_class = strategies.DeferredColumnLoader - else: - self.strategy_class = strategies.ColumnLoader - + + self.strategy_key = ( + ("deferred", self.deferred), + ("instrument", self.instrument) + ) + + @util.dependencies("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") + def _memoized_attr__deferred_column_loader(self, state, strategies): + return state.InstanceState._instance_level_callable_processor( + self.parent.class_manager, + strategies.LoadDeferredColumns(self.key), self.key) + + @property + def expression(self): + """Return the primary column or expression for this ColumnProperty. + + """ + return self.columns[0] + def instrument_class(self, mapper): if not self.instrument: return - + attributes.register_descriptor( - mapper.class_, - self.key, - comparator=self.comparator_factory(self, mapper), + mapper.class_, + self.key, + comparator=self.comparator_factory(self, mapper), parententity=mapper, - property_=self - ) - + doc=self.doc + ) + def do_init(self): super(ColumnProperty, self).do_init() - if len(self.columns) > 1 and self.parent.primary_key.issuperset(self.columns): + if len(self.columns) > 1 and \ + set(self.parent.primary_key).issuperset(self.columns): util.warn( ("On mapper %s, primary key column '%s' is being combined " "with distinct primary key column '%s' in attribute '%s'. 
" @@ -94,1112 +194,84 @@ class ColumnProperty(StrategizedProperty): self.columns[0], self.key)) def copy(self): - return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) + return ColumnProperty( + deferred=self.deferred, + group=self.group, + active_history=self.active_history, + *self.columns) - def getattr(self, state, column): - return state.get_impl(self.key).get(state, state.dict) + def _getcommitted(self, state, dict_, column, + passive=attributes.PASSIVE_OFF): + return state.get_impl(self.key).\ + get_committed_value(state, dict_, passive=passive) - def getcommitted(self, state, column, passive=False): - return state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive) - - def setattr(self, state, value, column): - state.get_impl(self.key).set(state, state.dict, value, None) - - def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): - if self.key in source_dict: + def merge(self, session, source_state, source_dict, dest_state, + dest_dict, load, _recursive, _resolve_conflict_map): + if not self.instrument: + return + elif self.key in source_dict: value = source_dict[self.key] - + if not load: dest_dict[self.key] = value else: impl = dest_state.get_impl(self.key) impl.set(dest_state, dest_dict, value, None) - else: - if self.key not in dest_dict: - dest_state.expire_attributes(dest_dict, [self.key]) - - def get_col_value(self, column, value): - return value + elif dest_state.has_identity and self.key not in dest_dict: + dest_state._expire_attributes( + dest_dict, [self.key], no_loader=True) - class Comparator(PropComparator): - @util.memoized_instancemethod - def __clause_element__(self): + class Comparator(util.MemoizedSlots, PropComparator): + """Produce boolean, comparison, and other operators for + :class:`.ColumnProperty` attributes. + + See the documentation for :class:`.PropComparator` for a brief + overview. + + See also: + + :class:`.PropComparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + __slots__ = '__clause_element__', 'info' + + def _memoized_method___clause_element__(self): if self.adapter: return self.adapter(self.prop.columns[0]) else: - return self.prop.columns[0]._annotate({"parententity": self.mapper, "parentmapper":self.mapper}) - + # no adapter, so we aren't aliased + # assert self._parententity is self._parentmapper + return self.prop.columns[0]._annotate({ + "parententity": self._parententity, + "parentmapper": self._parententity}) + + def _memoized_attr_info(self): + ce = self.__clause_element__() + try: + return ce.info + except AttributeError: + return self.prop.info + + def _fallback_getattr(self, key): + """proxy attribute access down to the mapped column. + + this allows user-defined comparison methods to be accessed. + """ + return getattr(self.__clause_element__(), key) + def operate(self, op, *other, **kwargs): return op(self.__clause_element__(), *other, **kwargs) def reverse_operate(self, op, other, **kwargs): col = self.__clause_element__() return op(col._bind_param(op, other), col, **kwargs) - - # TODO: legacy..do we need this ? (0.5) - ColumnComparator = Comparator - - def __str__(self): - return str(self.parent.class_.__name__) + "." 
+ self.key - -log.class_logger(ColumnProperty) - -class CompositeProperty(ColumnProperty): - """subclasses ColumnProperty to provide composite type support.""" - - def __init__(self, class_, *columns, **kwargs): - super(CompositeProperty, self).__init__(*columns, **kwargs) - self._col_position_map = util.column_dict((c, i) for i, c in enumerate(columns)) - self.composite_class = class_ - self.strategy_class = strategies.CompositeColumnLoader - - def copy(self): - return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns) - - def do_init(self): - # skip over ColumnProperty's do_init(), - # which issues assertions that do not apply to CompositeColumnProperty - super(ColumnProperty, self).do_init() - - def getattr(self, state, column): - obj = state.get_impl(self.key).get(state, state.dict) - return self.get_col_value(column, obj) - - def getcommitted(self, state, column, passive=False): - # TODO: no coverage here - obj = state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive) - return self.get_col_value(column, obj) - - def setattr(self, state, value, column): - - obj = state.get_impl(self.key).get(state, state.dict) - if obj is None: - obj = self.composite_class(*[None for c in self.columns]) - state.get_impl(self.key).set(state, state.dict, obj, None) - - if hasattr(obj, '__set_composite_values__'): - values = list(obj.__composite_values__()) - values[self._col_position_map[column]] = value - obj.__set_composite_values__(*values) - else: - setattr(obj, column.key, value) - - def get_col_value(self, column, value): - if value is None: - return None - for a, b in zip(self.columns, value.__composite_values__()): - if a is column: - return b - - class Comparator(PropComparator): - def __clause_element__(self): - if self.adapter: - # TODO: test coverage for adapted composite comparison - return expression.ClauseList(*[self.adapter(x) for x in self.prop.columns]) - else: - return expression.ClauseList(*self.prop.columns) - - __hash__ = None - - def __eq__(self, other): - if other is None: - values = [None] * len(self.prop.columns) - else: - values = other.__composite_values__() - return sql.and_(*[a==b for a, b in zip(self.prop.columns, values)]) - - def __ne__(self, other): - return sql.not_(self.__eq__(other)) def __str__(self): return str(self.parent.class_.__name__) + "." + self.key - -class ConcreteInheritedProperty(MapperProperty): - extension = None - - def setup(self, context, entity, path, adapter, **kwargs): - pass - - def create_row_processor(self, selectcontext, path, mapper, row, adapter): - return (None, None) - - def instrument_class(self, mapper): - def warn(): - raise AttributeError("Concrete %s does not implement attribute %r at " - "the instance level. Add this property explicitly to %s." % - (self.parent, self.key, self.parent)) - - class NoninheritedConcreteProp(object): - def __set__(s, obj, value): - warn() - def __delete__(s, obj): - warn() - def __get__(s, obj, owner): - warn() - - comparator_callable = None - # TODO: put this process into a deferred callable? 
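# [editor's note -- illustrative sketch, not part of the patch] The
# CompositeProperty plumbing removed above (getattr()/setattr(),
# __set_composite_values__ bookkeeping) is superseded by the composite()
# construct, which needs only __composite_values__() on the value class.
# A minimal sketch with hypothetical names:

from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import composite

Base = declarative_base()

class Point(object):
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __composite_values__(self):
        return self.x, self.y

    def __eq__(self, other):
        return isinstance(other, Point) and \
            (other.x, other.y) == (self.x, self.y)

class Vertex(Base):
    __tablename__ = 'vertex'
    id = Column(Integer, primary_key=True)
    x1 = Column(Integer)
    y1 = Column(Integer)
    # Vertex.start round-trips as a Point built from the two columns
    start = composite(Point, x1, y1)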
- for m in self.parent.iterate_to_root(): - p = m._get_property(self.key) - if not isinstance(p, ConcreteInheritedProperty): - comparator_callable = p.comparator_factory - break - - attributes.register_descriptor( - mapper.class_, - self.key, - comparator=comparator_callable(self, mapper), - parententity=mapper, - property_=self, - proxy_property=NoninheritedConcreteProp() - ) - - -class SynonymProperty(MapperProperty): - - extension = None - - def __init__(self, name, map_column=None, descriptor=None, comparator_factory=None): - self.name = name - self.map_column = map_column - self.descriptor = descriptor - self.comparator_factory = comparator_factory - util.set_creation_order(self) - - def setup(self, context, entity, path, adapter, **kwargs): - pass - - def create_row_processor(self, selectcontext, path, mapper, row, adapter): - return (None, None) - - def instrument_class(self, mapper): - class_ = self.parent.class_ - - if self.descriptor is None: - class SynonymProp(object): - def __set__(s, obj, value): - setattr(obj, self.name, value) - def __delete__(s, obj): - delattr(obj, self.name) - def __get__(s, obj, owner): - if obj is None: - return s - return getattr(obj, self.name) - - self.descriptor = SynonymProp() - - def comparator_callable(prop, mapper): - def comparator(): - prop = self.parent._get_property(self.key, resolve_synonyms=True) - if self.comparator_factory: - return self.comparator_factory(prop, mapper) - else: - return prop.comparator_factory(prop, mapper) - return comparator - - attributes.register_descriptor( - mapper.class_, - self.key, - comparator=comparator_callable(self, mapper), - parententity=mapper, - property_=self, - proxy_property=self.descriptor - ) - - def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): - pass - -log.class_logger(SynonymProperty) - -class ComparableProperty(MapperProperty): - """Instruments a Python property for use in query expressions.""" - - extension = None - - def __init__(self, comparator_factory, descriptor=None): - self.descriptor = descriptor - self.comparator_factory = comparator_factory - util.set_creation_order(self) - - def instrument_class(self, mapper): - """Set up a proxy to the unmanaged descriptor.""" - - attributes.register_descriptor( - mapper.class_, - self.key, - comparator=self.comparator_factory(self, mapper), - parententity=mapper, - property_=self, - proxy_property=self.descriptor - ) - - def setup(self, context, entity, path, adapter, **kwargs): - pass - - def create_row_processor(self, selectcontext, path, mapper, row, adapter): - return (None, None) - - def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): - pass - - -class RelationshipProperty(StrategizedProperty): - """Describes an object property that holds a single item or list - of items that correspond to a related database table. 
- """ - - def __init__(self, argument, - secondary=None, primaryjoin=None, - secondaryjoin=None, - foreign_keys=None, - uselist=None, - order_by=False, - backref=None, - back_populates=None, - post_update=False, - cascade=False, extension=None, - viewonly=False, lazy=True, - collection_class=None, passive_deletes=False, - passive_updates=True, remote_side=None, - enable_typechecks=True, join_depth=None, - comparator_factory=None, - single_parent=False, innerjoin=False, - strategy_class=None, _local_remote_pairs=None, query_class=None): - - self.uselist = uselist - self.argument = argument - self.secondary = secondary - self.primaryjoin = primaryjoin - self.secondaryjoin = secondaryjoin - self.post_update = post_update - self.direction = None - self.viewonly = viewonly - self.lazy = lazy - self.single_parent = single_parent - self._foreign_keys = foreign_keys - self.collection_class = collection_class - self.passive_deletes = passive_deletes - self.passive_updates = passive_updates - self.remote_side = remote_side - self.enable_typechecks = enable_typechecks - self.query_class = query_class - self.innerjoin = innerjoin - - self.join_depth = join_depth - self.local_remote_pairs = _local_remote_pairs - self.extension = extension - self.comparator_factory = comparator_factory or RelationshipProperty.Comparator - self.comparator = self.comparator_factory(self, None) - util.set_creation_order(self) - - if strategy_class: - self.strategy_class = strategy_class - elif self.lazy== 'dynamic': - from sqlalchemy.orm import dynamic - self.strategy_class = dynamic.DynaLoader - else: - self.strategy_class = strategies.factory(self.lazy) - - self._reverse_property = set() - - if cascade is not False: - self.cascade = CascadeOptions(cascade) - else: - self.cascade = CascadeOptions("save-update, merge") - - if self.passive_deletes == 'all' and ("delete" in self.cascade or "delete-orphan" in self.cascade): - raise sa_exc.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade") - - self.order_by = order_by - - self.back_populates = back_populates - - if self.back_populates: - if backref: - raise sa_exc.ArgumentError("backref and back_populates keyword arguments are mutually exclusive") - self.backref = None - else: - self.backref = backref - - def instrument_class(self, mapper): - attributes.register_descriptor( - mapper.class_, - self.key, - comparator=self.comparator_factory(self, mapper), - parententity=mapper, - property_=self - ) - - class Comparator(PropComparator): - def __init__(self, prop, mapper, of_type=None, adapter=None): - self.prop = prop - self.mapper = mapper - self.adapter = adapter - if of_type: - self._of_type = _class_to_mapper(of_type) - - def adapted(self, adapter): - """Return a copy of this PropComparator which will use the given adaption function - on the local side of generated expressions. 
- - """ - return self.__class__(self.property, self.mapper, getattr(self, '_of_type', None), adapter) - - @property - def parententity(self): - return self.property.parent - - def __clause_element__(self): - elem = self.property.parent._with_polymorphic_selectable - if self.adapter: - return self.adapter(elem) - else: - return elem - - def operate(self, op, *other, **kwargs): - return op(self, *other, **kwargs) - - def reverse_operate(self, op, other, **kwargs): - return op(self, *other, **kwargs) - - def of_type(self, cls): - return RelationshipProperty.Comparator(self.property, self.mapper, cls, adapter=self.adapter) - - def in_(self, other): - raise NotImplementedError("in_() not yet supported for relationships. For a " - "simple many-to-one, use in_() against the set of foreign key values.") - - __hash__ = None - - def __eq__(self, other): - if isinstance(other, (NoneType, expression._Null)): - if self.property.direction in [ONETOMANY, MANYTOMANY]: - return ~self._criterion_exists() - else: - return _orm_annotate(self.property._optimized_compare(None, adapt_source=self.adapter)) - elif self.property.uselist: - raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.") - else: - return _orm_annotate(self.property._optimized_compare(other, adapt_source=self.adapter)) - - def _criterion_exists(self, criterion=None, **kwargs): - if getattr(self, '_of_type', None): - target_mapper = self._of_type - to_selectable = target_mapper._with_polymorphic_selectable - if self.property._is_self_referential(): - to_selectable = to_selectable.alias() - - single_crit = target_mapper._single_table_criterion - if single_crit is not None: - if criterion is not None: - criterion = single_crit & criterion - else: - criterion = single_crit - else: - to_selectable = None - - if self.adapter: - source_selectable = self.__clause_element__() - else: - source_selectable = None - - pj, sj, source, dest, secondary, target_adapter = \ - self.property._create_joins(dest_polymorphic=True, dest_selectable=to_selectable, source_selectable=source_selectable) - - for k in kwargs: - crit = self.property.mapper.class_manager[k] == kwargs[k] - if criterion is None: - criterion = crit - else: - criterion = criterion & crit - - # annotate the *local* side of the join condition, in the case of pj + sj this - # is the full primaryjoin, in the case of just pj its the local side of - # the primaryjoin. - if sj is not None: - j = _orm_annotate(pj) & sj - else: - j = _orm_annotate(pj, exclude=self.property.remote_side) - - if criterion is not None and target_adapter: - # limit this adapter to annotated only? - criterion = target_adapter.traverse(criterion) - - # only have the "joined left side" of what we return be subject to Query adaption. The right - # side of it is used for an exists() subquery and should not correlate or otherwise reach out - # to anything in the enclosing query. - if criterion is not None: - criterion = criterion._annotate({'_halt_adapt': True}) - - crit = j & criterion - - return sql.exists([1], crit, from_obj=dest).correlate(source) - - def any(self, criterion=None, **kwargs): - if not self.property.uselist: - raise sa_exc.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().") - - return self._criterion_exists(criterion, **kwargs) - - def has(self, criterion=None, **kwargs): - if self.property.uselist: - raise sa_exc.InvalidRequestError("'has()' not implemented for collections. 
Use any().") - return self._criterion_exists(criterion, **kwargs) - - def contains(self, other, **kwargs): - if not self.property.uselist: - raise sa_exc.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==") - clause = self.property._optimized_compare(other, adapt_source=self.adapter) - - if self.property.secondaryjoin is not None: - clause.negation_clause = self.__negated_contains_or_equals(other) - - return clause - - def __negated_contains_or_equals(self, other): - if self.property.direction == MANYTOONE: - state = attributes.instance_state(other) - strategy = self.property._get_strategy(strategies.LazyLoader) - - def state_bindparam(state, col): - o = state.obj() # strong ref - return lambda: self.property.mapper._get_committed_attr_by_column(o, col) - - def adapt(col): - if self.adapter: - return self.adapter(col) - else: - return col - - if strategy.use_get: - return sql.and_(*[ - sql.or_( - adapt(x) != state_bindparam(state, y), - adapt(x) == None) - for (x, y) in self.property.local_remote_pairs]) - - criterion = sql.and_(*[x==y for (x, y) in zip(self.property.mapper.primary_key, self.property.mapper.primary_key_from_instance(other))]) - return ~self._criterion_exists(criterion) - - def __ne__(self, other): - if isinstance(other, (NoneType, expression._Null)): - if self.property.direction == MANYTOONE: - return sql.or_(*[x!=None for x in self.property._foreign_keys]) - else: - return self._criterion_exists() - elif self.property.uselist: - raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.") - else: - return self.__negated_contains_or_equals(other) - - @util.memoized_property - def property(self): - self.prop.parent.compile() - return self.prop - - def compare(self, op, value, value_is_parent=False, alias_secondary=True): - if op == operators.eq: - if value is None: - if self.uselist: - return ~sql.exists([1], self.primaryjoin) - else: - return self._optimized_compare(None, - value_is_parent=value_is_parent, - alias_secondary=alias_secondary) - else: - return self._optimized_compare(value, - value_is_parent=value_is_parent, - alias_secondary=alias_secondary) - else: - return op(self.comparator, value) - - def _optimized_compare(self, value, value_is_parent=False, - adapt_source=None, alias_secondary=True): - if value is not None: - value = attributes.instance_state(value) - return self._get_strategy(strategies.LazyLoader).\ - lazy_clause(value, - reverse_direction=not value_is_parent, - alias_secondary=alias_secondary, adapt_source=adapt_source) - - def __str__(self): - return str(self.parent.class_.__name__) + "." + self.key - - def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): - if load: - # TODO: no test coverage for recursive check - for r in self._reverse_property: - if (source_state, r) in _recursive: - return - - if not "merge" in self.cascade: - return - - if self.key not in source_dict: - return - - if self.uselist: - instances = source_state.get_impl(self.key).\ - get(source_state, source_dict) - if hasattr(instances, '_sa_adapter'): - # convert collections to adapters to get a true iterator - instances = instances._sa_adapter - - if load: - # for a full merge, pre-load the destination collection, - # so that individual _merge of each item pulls from identity - # map for those already present. 
- # also assumes CollectionAttrbiuteImpl behavior of loading - # "old" list in any case - dest_state.get_impl(self.key).get(dest_state, dest_dict) - - dest_list = [] - for current in instances: - current_state = attributes.instance_state(current) - current_dict = attributes.instance_dict(current) - _recursive[(current_state, self)] = True - obj = session._merge(current_state, current_dict, load=load, _recursive=_recursive) - if obj is not None: - dest_list.append(obj) - - if not load: - coll = attributes.init_state_collection(dest_state, dest_dict, self.key) - for c in dest_list: - coll.append_without_event(c) - else: - dest_state.get_impl(self.key)._set_iterable(dest_state, dest_dict, dest_list) - else: - current = source_dict[self.key] - if current is not None: - current_state = attributes.instance_state(current) - current_dict = attributes.instance_dict(current) - _recursive[(current_state, self)] = True - obj = session._merge(current_state, current_dict, load=load, _recursive=_recursive) - else: - obj = None - - if not load: - dest_dict[self.key] = obj - else: - dest_state.get_impl(self.key).set(dest_state, dest_dict, obj, None) - - def cascade_iterator(self, type_, state, visited_instances, halt_on=None): - if not type_ in self.cascade: - return - - # only actively lazy load on the 'delete' cascade - if type_ != 'delete' or self.passive_deletes: - passive = attributes.PASSIVE_NO_INITIALIZE - else: - passive = attributes.PASSIVE_OFF - - if type_ == 'save-update': - instances = attributes.get_state_history(state, self.key, passive=passive).sum() - else: - instances = state.value_as_iterable(self.key, passive=passive) - - if instances: - for c in instances: - if c is not None and \ - c is not attributes.PASSIVE_NO_RESULT and \ - c not in visited_instances and \ - (halt_on is None or not halt_on(c)): - - if not isinstance(c, self.mapper.class_): - raise AssertionError("Attribute '%s' on class '%s' " - "doesn't handle objects " - "of type '%s'" % ( - self.key, - str(self.parent.class_), - str(c.__class__) - )) - visited_instances.add(c) - - # cascade using the mapper local to this - # object, so that its individual properties are located - instance_mapper = object_mapper(c) - yield (c, instance_mapper, attributes.instance_state(c)) - - def _add_reverse_property(self, key): - other = self.mapper._get_property(key) - self._reverse_property.add(other) - other._reverse_property.add(self) - - if not other._get_target().common_parent(self.parent): - raise sa_exc.ArgumentError("reverse_property %r on relationship %s references " - "relationship %s, which does not reference mapper %s" % (key, self, other, self.parent)) - - if self.direction in (ONETOMANY, MANYTOONE) and self.direction == other.direction: - raise sa_exc.ArgumentError("%s and back-reference %s are both of the same direction %r." - " Did you mean to set remote_side on the many-to-one side ?" 
% (other, self, self.direction)) - - def do_init(self): - self._get_target() - self._assert_is_primary() - self._process_dependent_arguments() - self._determine_joins() - self._determine_synchronize_pairs() - self._determine_direction() - self._determine_local_remote_pairs() - self._post_init() - self._generate_backref() - super(RelationshipProperty, self).do_init() - - def _get_target(self): - if not hasattr(self, 'mapper'): - if isinstance(self.argument, type): - self.mapper = mapper.class_mapper(self.argument, compile=False) - elif isinstance(self.argument, mapper.Mapper): - self.mapper = self.argument - elif util.callable(self.argument): - # accept a callable to suit various deferred-configurational schemes - self.mapper = mapper.class_mapper(self.argument(), compile=False) - else: - raise sa_exc.ArgumentError("relationship '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument))) - assert isinstance(self.mapper, mapper.Mapper), self.mapper - return self.mapper - - def _process_dependent_arguments(self): - - # accept callables for other attributes which may require deferred initialization - for attr in ('order_by', 'primaryjoin', 'secondaryjoin', 'secondary', '_foreign_keys', 'remote_side'): - if util.callable(getattr(self, attr)): - setattr(self, attr, getattr(self, attr)()) - - # in the case that InstrumentedAttributes were used to construct - # primaryjoin or secondaryjoin, remove the "_orm_adapt" annotation so these - # interact with Query in the same way as the original Table-bound Column objects - for attr in ('primaryjoin', 'secondaryjoin'): - val = getattr(self, attr) - if val is not None: - util.assert_arg_type(val, sql.ColumnElement, attr) - setattr(self, attr, _orm_deannotate(val)) - - if self.order_by is not False and self.order_by is not None: - self.order_by = [expression._literal_as_column(x) for x in util.to_list(self.order_by)] - - self._foreign_keys = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self._foreign_keys)) - self.remote_side = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self.remote_side)) - - if not self.parent.concrete: - for inheriting in self.parent.iterate_to_root(): - if inheriting is not self.parent and inheriting._get_property(self.key, raiseerr=False): - util.warn( - ("Warning: relationship '%s' on mapper '%s' supercedes " - "the same relationship on inherited mapper '%s'; this " - "can cause dependency issues during flush") % - (self.key, self.parent, inheriting)) - - # TODO: remove 'self.table' - self.target = self.table = self.mapper.mapped_table - - if self.cascade.delete_orphan: - if self.parent.class_ is self.mapper.class_: - raise sa_exc.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade " - "rule on a self-referential relationship. " - "You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self))) - self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_)) - - def _determine_joins(self): - if self.secondaryjoin is not None and self.secondary is None: - raise sa_exc.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument") - # if join conditions were not specified, figure them out based on foreign keys - - def _search_for_join(mapper, table): - # find a join between the given mapper's mapped table and the given table. 
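# Sketch of what the join inference above produces: with a single foreign
# key path between the tables, the derived condition is equivalent to
# passing primaryjoin explicitly.  Table names here are illustrative.
from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey

metadata = MetaData()
parent_table = Table('parent', metadata,
                     Column('id', Integer, primary_key=True))
child_table = Table('child', metadata,
                    Column('id', Integer, primary_key=True),
                    Column('parent_id', Integer, ForeignKey('parent.id')))

# inferred automatically:  parent.id == child.parent_id
# spelled out, as the error messages below suggest when inference fails:
#   relationship(Child,
#                primaryjoin=parent_table.c.id == child_table.c.parent_id,
#                foreign_keys=[child_table.c.parent_id])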
- # will try the mapper's local table first for more specificity, then if not - # found will try the more general mapped table, which in the case of inheritance - # is a join. - try: - return join_condition(mapper.local_table, table) - except sa_exc.ArgumentError, e: - return join_condition(mapper.mapped_table, table) - - try: - if self.secondary is not None: - if self.secondaryjoin is None: - self.secondaryjoin = _search_for_join(self.mapper, self.secondary) - if self.primaryjoin is None: - self.primaryjoin = _search_for_join(self.parent, self.secondary) - else: - if self.primaryjoin is None: - self.primaryjoin = _search_for_join(self.parent, self.target) - except sa_exc.ArgumentError, e: - raise sa_exc.ArgumentError("Could not determine join condition between " - "parent/child tables on relationship %s. " - "Specify a 'primaryjoin' expression. If this is a " - "many-to-many relationship, 'secondaryjoin' is needed as well." % (self)) - - def _col_is_part_of_mappings(self, column): - if self.secondary is None: - return self.parent.mapped_table.c.contains_column(column) or \ - self.target.c.contains_column(column) - else: - return self.parent.mapped_table.c.contains_column(column) or \ - self.target.c.contains_column(column) or \ - self.secondary.c.contains_column(column) is not None - - def _determine_synchronize_pairs(self): - - if self.local_remote_pairs: - if not self._foreign_keys: - raise sa_exc.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument") - - self.synchronize_pairs = [] - - for l, r in self.local_remote_pairs: - if r in self._foreign_keys: - self.synchronize_pairs.append((l, r)) - elif l in self._foreign_keys: - self.synchronize_pairs.append((r, l)) - else: - eq_pairs = criterion_as_pairs( - self.primaryjoin, - consider_as_foreign_keys=self._foreign_keys, - any_operator=self.viewonly - ) - eq_pairs = [ - (l, r) for l, r in eq_pairs if - (self._col_is_part_of_mappings(l) and - self._col_is_part_of_mappings(r)) - or self.viewonly and r in self._foreign_keys - ] - - if not eq_pairs: - if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True): - raise sa_exc.ArgumentError("Could not locate any equated, locally " - "mapped column pairs for primaryjoin condition '%s' on relationship %s. " - "For more relaxed rules on join conditions, the relationship may be " - "marked as viewonly=True." % (self.primaryjoin, self) - ) - else: - if self._foreign_keys: - raise sa_exc.ArgumentError("Could not determine relationship direction for " - "primaryjoin condition '%s', on relationship %s. " - "Do the columns in 'foreign_keys' represent only the 'foreign' columns " - "in this join condition ?" % (self.primaryjoin, self)) - else: - raise sa_exc.ArgumentError("Could not determine relationship direction for " - "primaryjoin condition '%s', on relationship %s. " - "Specify the 'foreign_keys' argument to indicate which columns " - "on the relationship are foreign." 
% (self.primaryjoin, self)) - - self.synchronize_pairs = eq_pairs - - if self.secondaryjoin is not None: - sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=self.viewonly) - sq_pairs = [(l, r) for l, r in sq_pairs if (self._col_is_part_of_mappings(l) and self._col_is_part_of_mappings(r)) or r in self._foreign_keys] - - if not sq_pairs: - if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True): - raise sa_exc.ArgumentError("Could not locate any equated, locally mapped " - "column pairs for secondaryjoin condition '%s' on relationship %s. " - "For more relaxed rules on join conditions, the " - "relationship may be marked as viewonly=True." % (self.secondaryjoin, self) - ) - else: - raise sa_exc.ArgumentError("Could not determine relationship direction " - "for secondaryjoin condition '%s', on relationship %s. " - "Specify the foreign_keys argument to indicate which " - "columns on the relationship are foreign." % (self.secondaryjoin, self)) - - self.secondary_synchronize_pairs = sq_pairs - else: - self.secondary_synchronize_pairs = None - - self._foreign_keys = util.column_set(r for l, r in self.synchronize_pairs) - if self.secondary_synchronize_pairs: - self._foreign_keys.update(r for l, r in self.secondary_synchronize_pairs) - - def _determine_direction(self): - if self.secondaryjoin is not None: - self.direction = MANYTOMANY - - elif self._refers_to_parent_table(): - # self referential defaults to ONETOMANY unless the "remote" side is present - # and does not reference any foreign key columns - - if self.local_remote_pairs: - remote = [r for l, r in self.local_remote_pairs] - elif self.remote_side: - remote = self.remote_side - else: - remote = None - - if not remote or self._foreign_keys.\ - difference(l for l, r in self.synchronize_pairs).\ - intersection(remote): - self.direction = ONETOMANY - else: - self.direction = MANYTOONE - - else: - foreign_keys = [f for c, f in self.synchronize_pairs] - - parentcols = util.column_set(self.parent.mapped_table.c) - targetcols = util.column_set(self.mapper.mapped_table.c) - - # fk collection which suggests ONETOMANY. - onetomany_fk = targetcols.intersection(foreign_keys) - - # fk collection which suggests MANYTOONE. - manytoone_fk = parentcols.intersection(foreign_keys) - - if not onetomany_fk and not manytoone_fk: - raise sa_exc.ArgumentError( - "Can't determine relationship direction for relationship '%s' " - "- foreign key columns are present in neither the " - "parent nor the child's mapped tables" % self ) - - elif onetomany_fk and manytoone_fk: - # fks on both sides. do the same - # test only based on the local side. - referents = [c for c, f in self.synchronize_pairs] - onetomany_local = parentcols.intersection(referents) - manytoone_local = targetcols.intersection(referents) - - if onetomany_local and not manytoone_local: - self.direction = ONETOMANY - elif manytoone_local and not onetomany_local: - self.direction = MANYTOONE - elif onetomany_fk: - self.direction = ONETOMANY - elif manytoone_fk: - self.direction = MANYTOONE - - if not self.direction: - raise sa_exc.ArgumentError( - "Can't determine relationship direction for relationship '%s' " - "- foreign key columns are present in both the parent and " - "the child's mapped tables. Specify 'foreign_keys' " - "argument." 
% self) - - if self.cascade.delete_orphan and not self.single_parent and \ - (self.direction is MANYTOMANY or self.direction is MANYTOONE): - util.warn("On %s, delete-orphan cascade is not supported on a " - "many-to-many or many-to-one relationship when single_parent is not set. " - " Set single_parent=True on the relationship()." % self) - - def _determine_local_remote_pairs(self): - if not self.local_remote_pairs: - if self.remote_side: - if self.direction is MANYTOONE: - self.local_remote_pairs = [ - (r, l) for l, r in - criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True) - ] - else: - self.local_remote_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.remote_side, any_operator=True) - - if not self.local_remote_pairs: - raise sa_exc.ArgumentError("Relationship %s could not determine any local/remote column pairs from remote side argument %r" % (self, self.remote_side)) - - else: - if self.viewonly: - eq_pairs = self.synchronize_pairs - if self.secondaryjoin is not None: - eq_pairs += self.secondary_synchronize_pairs - else: - eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True) - if self.secondaryjoin is not None: - eq_pairs += criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self._foreign_keys, any_operator=True) - eq_pairs = [(l, r) for l, r in eq_pairs if self._col_is_part_of_mappings(l) and self._col_is_part_of_mappings(r)] - - if self.direction is MANYTOONE: - self.local_remote_pairs = [(r, l) for l, r in eq_pairs] - else: - self.local_remote_pairs = eq_pairs - elif self.remote_side: - raise sa_exc.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.") - - for l, r in self.local_remote_pairs: - - if self.direction is ONETOMANY and not self._col_is_part_of_mappings(l): - raise sa_exc.ArgumentError("Local column '%s' is not part of mapping %s. " - "Specify remote_side argument to indicate which column " - "lazy join condition should compare against." % (l, self.parent)) - - elif self.direction is MANYTOONE and not self._col_is_part_of_mappings(r): - raise sa_exc.ArgumentError("Remote column '%s' is not part of mapping %s. " - "Specify remote_side argument to indicate which column lazy " - "join condition should bind." % (r, self.mapper)) - - self.local_side, self.remote_side = [util.ordered_column_set(x) for x in zip(*list(self.local_remote_pairs))] - - def _assert_is_primary(self): - if not self.is_primary() and \ - not mapper.class_mapper(self.parent.class_, compile=False)._get_property(self.key, raiseerr=False): - - raise sa_exc.ArgumentError("Attempting to assign a new relationship '%s' to " - "a non-primary mapper on class '%s'. New relationships can only be " - "added to the primary mapper, i.e. 
the very first " - "mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__)) - - def _generate_backref(self): - if not self.is_primary(): - return - - if self.backref is not None and not self.back_populates: - if isinstance(self.backref, basestring): - backref_key, kwargs = self.backref, {} - else: - backref_key, kwargs = self.backref - - mapper = self.mapper.primary_mapper() - if mapper._get_property(backref_key, raiseerr=False) is not None: - raise sa_exc.ArgumentError("Error creating backref '%s' on relationship '%s': " - "property of that name exists on mapper '%s'" % (backref_key, self, mapper)) - - if self.secondary is not None: - pj = kwargs.pop('primaryjoin', self.secondaryjoin) - sj = kwargs.pop('secondaryjoin', self.primaryjoin) - else: - pj = kwargs.pop('primaryjoin', self.primaryjoin) - sj = kwargs.pop('secondaryjoin', None) - if sj: - raise sa_exc.InvalidRequestError( - "Can't assign 'secondaryjoin' on a backref against " - "a non-secondary relationship.") - - foreign_keys = kwargs.pop('foreign_keys', self._foreign_keys) - - parent = self.parent.primary_mapper() - kwargs.setdefault('viewonly', self.viewonly) - kwargs.setdefault('post_update', self.post_update) - - self.back_populates = backref_key - relationship = RelationshipProperty( - parent, - self.secondary, - pj, - sj, - foreign_keys=foreign_keys, - back_populates=self.key, - **kwargs) - - mapper._configure_property(backref_key, relationship) - - - if self.back_populates: - self.extension = list(util.to_list(self.extension, default=[])) - self.extension.append(attributes.GenericBackrefExtension(self.back_populates)) - self._add_reverse_property(self.back_populates) - - - def _post_init(self): - self.logger.info("%s setup primary join %s", self, self.primaryjoin) - self.logger.info("%s setup secondary join %s", self, self.secondaryjoin) - self.logger.info("%s synchronize pairs [%s]", self, ",".join("(%s => %s)" % (l, r) for l, r in self.synchronize_pairs)) - self.logger.info("%s secondary synchronize pairs [%s]", self, ",".join(("(%s => %s)" % (l, r) for l, r in self.secondary_synchronize_pairs or []))) - self.logger.info("%s local/remote pairs [%s]", self, ",".join("(%s / %s)" % (l, r) for l, r in self.local_remote_pairs)) - self.logger.info("%s relationship direction %s", self, self.direction) - - if self.uselist is None: - self.uselist = self.direction is not MANYTOONE - - if not self.viewonly: - self._dependency_processor = dependency.create_dependency_processor(self) - - def _refers_to_parent_table(self): - for c, f in self.synchronize_pairs: - if c.table is f.table: - return True - else: - return False - - def _is_self_referential(self): - return self.mapper.common_parent(self.parent) - - def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None, of_type=None): - if source_selectable is None: - if source_polymorphic and self.parent.with_polymorphic: - source_selectable = self.parent._with_polymorphic_selectable - - aliased = False - if dest_selectable is None: - if dest_polymorphic and self.mapper.with_polymorphic: - dest_selectable = self.mapper._with_polymorphic_selectable - aliased = True - else: - dest_selectable = self.mapper.mapped_table - - if self._is_self_referential() and source_selectable is None: - dest_selectable = dest_selectable.alias() - aliased = True - else: - aliased = True - - aliased = aliased or (source_selectable is not None) - - primaryjoin, secondaryjoin, secondary = self.primaryjoin, 
-            self.secondaryjoin, self.secondary
-
-        # adjust the join condition for single table inheritance,
-        # in the case that the join is to a subclass
-        # this is analgous to the "_adjust_for_single_table_inheritance()"
-        # method in Query.
-
-        dest_mapper = of_type or self.mapper
-
-        single_crit = dest_mapper._single_table_criterion
-        if single_crit is not None:
-            if secondaryjoin is not None:
-                secondaryjoin = secondaryjoin & single_crit
-            else:
-                primaryjoin = primaryjoin & single_crit
-
-
-        if aliased:
-            if secondary is not None:
-                secondary = secondary.alias()
-                primary_aliasizer = ClauseAdapter(secondary)
-                if dest_selectable is not None:
-                    secondary_aliasizer = ClauseAdapter(dest_selectable, equivalents=self.mapper._equivalent_columns).chain(primary_aliasizer)
-                else:
-                    secondary_aliasizer = primary_aliasizer
-
-                if source_selectable is not None:
-                    primary_aliasizer = ClauseAdapter(secondary).chain(ClauseAdapter(source_selectable, equivalents=self.parent._equivalent_columns))
-
-                secondaryjoin = secondary_aliasizer.traverse(secondaryjoin)
-            else:
-                if dest_selectable is not None:
-                    primary_aliasizer = ClauseAdapter(dest_selectable, exclude=self.local_side, equivalents=self.mapper._equivalent_columns)
-                    if source_selectable is not None:
-                        primary_aliasizer.chain(ClauseAdapter(source_selectable, exclude=self.remote_side, equivalents=self.parent._equivalent_columns))
-                elif source_selectable is not None:
-                    primary_aliasizer = ClauseAdapter(source_selectable, exclude=self.remote_side, equivalents=self.parent._equivalent_columns)
-
-                secondary_aliasizer = None
-
-            primaryjoin = primary_aliasizer.traverse(primaryjoin)
-            target_adapter = secondary_aliasizer or primary_aliasizer
-            target_adapter.include = target_adapter.exclude = None
-        else:
-            target_adapter = None
-
-        if source_selectable is None:
-            source_selectable = self.parent.local_table
-
-        if dest_selectable is None:
-            dest_selectable = self.mapper.local_table
-
-        return (primaryjoin, secondaryjoin,
-                source_selectable,
-                dest_selectable, secondary, target_adapter)
-
-    def register_dependencies(self, uowcommit):
-        if not self.viewonly:
-            self._dependency_processor.register_dependencies(uowcommit)
-
-    def register_processors(self, uowcommit):
-        if not self.viewonly:
-            self._dependency_processor.register_processors(uowcommit)
-
-PropertyLoader = RelationProperty = RelationshipProperty
-log.class_logger(RelationshipProperty)
-
-mapper.ColumnProperty = ColumnProperty
-mapper.SynonymProperty = SynonymProperty
-mapper.ComparableProperty = ComparableProperty
-mapper.RelationshipProperty = RelationshipProperty
-mapper.ConcreteInheritedProperty = ConcreteInheritedProperty
diff --git a/sqlalchemy/orm/query.py b/sqlalchemy/orm/query.py
index e98ad89..e8bd717 100644
--- a/sqlalchemy/orm/query.py
+++ b/sqlalchemy/orm/query.py
@@ -1,92 +1,139 @@
 # orm/query.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
+# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
+#
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 """The Query class and support.
 
-Defines the :class:`~sqlalchemy.orm.query.Query` class, the central construct used by
-the ORM to construct database queries.
+Defines the :class:`.Query` class, the central
+construct used by the ORM to construct database queries.
 
-The ``Query`` class should not be confused with the
-:class:`~sqlalchemy.sql.expression.Select` class, which defines database SELECT
-operations at the SQL (non-ORM) level. ``Query`` differs from ``Select`` in
-that it returns ORM-mapped objects and interacts with an ORM session, whereas
-the ``Select`` construct interacts directly with the database to return
-iterable result sets.
+The :class:`.Query` class should not be confused with the
+:class:`.Select` class, which defines database
+SELECT operations at the SQL (non-ORM) level. ``Query`` differs from
+``Select`` in that it returns ORM-mapped objects and interacts with an
+ORM session, whereas the ``Select`` construct interacts directly with the
+database to return iterable result sets.
 
 """
 
 from itertools import chain
-from operator import itemgetter
-
-from sqlalchemy import sql, util, log, schema
-from sqlalchemy import exc as sa_exc
-from sqlalchemy.orm import exc as orm_exc
-from sqlalchemy.sql import util as sql_util
-from sqlalchemy.sql import expression, visitors, operators
-from sqlalchemy.orm import (
-    attributes, interfaces, mapper, object_mapper, evaluator,
-    )
-from sqlalchemy.orm.util import (
-    AliasedClass, ORMAdapter, _entity_descriptor, _entity_info,
-    _is_aliased_class, _is_mapped_class, _orm_columns, _orm_selectable,
-    join as orm_join,
-    )
+from . import (
+    attributes, interfaces, object_mapper, persistence,
+    exc as orm_exc, loading
+)
+from .base import _entity_descriptor, _is_aliased_class, \
+    _is_mapped_class, _orm_columns, _generative, InspectionAttr
+from .path_registry import PathRegistry
+from .util import (
+    AliasedClass, ORMAdapter, join as orm_join, with_parent, aliased
+)
+from .. import sql, util, log, exc as sa_exc, inspect, inspection
+from ..sql.expression import _interpret_as_from
+from ..sql import (
+    util as sql_util,
+    expression, visitors
+)
+from ..sql.base import ColumnCollection
+from . import properties
 
 __all__ = ['Query', 'QueryContext', 'aliased']
 
-aliased = AliasedClass
+_path_registry = PathRegistry.root
 
-def _generative(*assertions):
-    """Mark a method as generative."""
-
-    @util.decorator
-    def generate(fn, *args, **kw):
-        self = args[0]._clone()
-        for assertion in assertions:
-            assertion(self, fn.func_name)
-        fn(self, *args[1:], **kw)
-        return self
-    return generate
 
+@inspection._self_inspects
+@log.class_logger
 class Query(object):
-    """ORM-level SQL construction object."""
-
+    """ORM-level SQL construction object.
+
+    :class:`.Query` is the source of all SELECT statements generated by the
+    ORM, both those formulated by end-user query operations as well as by
+    high level internal operations such as related collection loading. It
+    features a generative interface whereby successive calls return a new
+    :class:`.Query` object, a copy of the former with additional
+    criteria and options associated with it.
+
+    :class:`.Query` objects are normally initially generated using the
+    :meth:`~.Session.query` method of :class:`.Session`, and in
+    less common cases by instantiating the :class:`.Query` directly and
+    associating with a :class:`.Session` using the :meth:`.Query.with_session`
+    method.
+
+    For a full walkthrough of :class:`.Query` usage, see the
+    :ref:`ormtutorial_toplevel`.
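# The generative behavior described above, sketched; User is an assumed
# mapped class and `session` an active Session.  Each call returns a new
# Query rather than mutating the original:
q1 = session.query(User)
q2 = q1.filter(User.name == 'ed').order_by(User.id)
assert q1 is not q2  # q1 still selects all User rows

# contrast with the Core-level select(), which returns plain rows:
from sqlalchemy import select
stmt = select([User.__table__]).where(User.__table__.c.name == 'ed')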
+ + """ + _enable_eagerloads = True _enable_assertions = True _with_labels = False _criterion = None _yield_per = None - _lockmode = None _order_by = False _group_by = False _having = None _distinct = False + _prefixes = None + _suffixes = None _offset = None _limit = None + _for_update_arg = None _statement = None _correlate = frozenset() _populate_existing = False + _invoke_all_eagers = True _version_check = False _autoflush = True - _current_path = () _only_load_props = None _refresh_state = None _from_obj = () + _join_entities = () + _select_from_entity = None + _mapper_adapter_map = {} _filter_aliases = None _from_obj_alias = None - _joinpath = _joinpoint = util.frozendict() - _execution_options = util.frozendict() - _params = util.frozendict() - _attributes = util.frozendict() + _joinpath = _joinpoint = util.immutabledict() + _execution_options = util.immutabledict() + _params = util.immutabledict() + _attributes = util.immutabledict() _with_options = () _with_hints = () - + _enable_single_crit = True + _orm_only_adapt = True + _orm_only_from_obj_alias = True + _current_path = _path_registry + _has_mapper_entities = False + def __init__(self, entities, session=None): + """Construct a :class:`.Query` directly. + + E.g.:: + + q = Query([User, Address], session=some_session) + + The above is equivalent to:: + + q = some_session.query(User, Address) + + :param entities: a sequence of entities and/or SQL expressions. + + :param session: a :class:`.Session` with which the :class:`.Query` + will be associated. Optional; a :class:`.Query` can be associated + with a :class:`.Session` generatively via the + :meth:`.Query.with_session` method as well. + + .. seealso:: + + :meth:`.Session.query` + + :meth:`.Query.with_session` + + """ self.session = session self._polymorphic_adapters = {} self._set_entities(entities) @@ -95,193 +142,236 @@ class Query(object): if entity_wrapper is None: entity_wrapper = _QueryEntity self._entities = [] + self._primary_entity = None + self._has_mapper_entities = False for ent in util.to_list(entities): entity_wrapper(self, ent) - self._setup_aliasizers(self._entities) + self._set_entity_selectables(self._entities) - def _setup_aliasizers(self, entities): - if hasattr(self, '_mapper_adapter_map'): - # usually safe to share a single map, but copying to prevent - # subtle leaks if end-user is reusing base query with arbitrary - # number of aliased() objects - self._mapper_adapter_map = d = self._mapper_adapter_map.copy() - else: - self._mapper_adapter_map = d = {} + def _set_entity_selectables(self, entities): + self._mapper_adapter_map = d = self._mapper_adapter_map.copy() for ent in entities: for entity in ent.entities: if entity not in d: - mapper, selectable, is_aliased_class = _entity_info(entity) - if not is_aliased_class and mapper.with_polymorphic: - with_polymorphic = mapper._with_polymorphic_mappers - if mapper.mapped_table not in self._polymorphic_adapters: - self.__mapper_loads_polymorphically_with(mapper, - sql_util.ColumnAdapter(selectable, mapper._equivalent_columns)) - adapter = None - elif is_aliased_class: - adapter = sql_util.ColumnAdapter(selectable, mapper._equivalent_columns) - with_polymorphic = None + ext_info = inspect(entity) + if not ext_info.is_aliased_class and \ + ext_info.mapper.with_polymorphic: + if ext_info.mapper.mapped_table not in \ + self._polymorphic_adapters: + self._mapper_loads_polymorphically_with( + ext_info.mapper, + sql_util.ColumnAdapter( + ext_info.selectable, + ext_info.mapper._equivalent_columns + ) + ) + 
aliased_adapter = None + elif ext_info.is_aliased_class: + aliased_adapter = ext_info._adapter else: - with_polymorphic = adapter = None + aliased_adapter = None - d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic) - ent.setup_entity(entity, *d[entity]) + d[entity] = ( + ext_info, + aliased_adapter + ) + ent.setup_entity(*d[entity]) - def __mapper_loads_polymorphically_with(self, mapper, adapter): - for m2 in mapper._with_polymorphic_mappers: + def _mapper_loads_polymorphically_with(self, mapper, adapter): + for m2 in mapper._with_polymorphic_mappers or [mapper]: self._polymorphic_adapters[m2] = adapter for m in m2.iterate_to_root(): - self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter - - def _set_select_from(self, *obj): + self._polymorphic_adapters[m.local_table] = adapter + def _set_select_from(self, obj, set_base_alias): fa = [] + select_from_alias = None + for from_obj in obj: - if isinstance(from_obj, expression._SelectBaseMixin): - from_obj = from_obj.alias() - fa.append(from_obj) + info = inspect(from_obj) + if hasattr(info, 'mapper') and \ + (info.is_mapper or info.is_aliased_class): + self._select_from_entity = info + if set_base_alias and not info.is_aliased_class: + raise sa_exc.ArgumentError( + "A selectable (FromClause) instance is " + "expected when the base alias is being set.") + fa.append(info.selectable) + elif not info.is_selectable: + raise sa_exc.ArgumentError( + "argument is not a mapped class, mapper, " + "aliased(), or FromClause instance.") + else: + if isinstance(from_obj, expression.SelectBase): + from_obj = from_obj.alias() + if set_base_alias: + select_from_alias = from_obj + fa.append(from_obj) self._from_obj = tuple(fa) - if len(self._from_obj) == 1 and \ - isinstance(self._from_obj[0], expression.Alias): + if set_base_alias and \ + len(self._from_obj) == 1 and \ + isinstance(select_from_alias, expression.Alias): equivs = self.__all_equivs() - self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj[0], equivs) - - def _get_polymorphic_adapter(self, entity, selectable): - self.__mapper_loads_polymorphically_with(entity.mapper, - sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns)) + self._from_obj_alias = sql_util.ColumnAdapter( + self._from_obj[0], equivs) + elif set_base_alias and \ + len(self._from_obj) == 1 and \ + hasattr(info, "mapper") and \ + info.is_aliased_class: + self._from_obj_alias = info._adapter def _reset_polymorphic_adapter(self, mapper): for m2 in mapper._with_polymorphic_mappers: self._polymorphic_adapters.pop(m2, None) for m in m2.iterate_to_root(): - self._polymorphic_adapters.pop(m.mapped_table, None) self._polymorphic_adapters.pop(m.local_table, None) - def __adapt_polymorphic_element(self, element): + def _adapt_polymorphic_element(self, element): + if "parententity" in element._annotations: + search = element._annotations['parententity'] + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + if isinstance(element, expression.FromClause): search = element elif hasattr(element, 'table'): search = element.table else: - search = None + return None - if search is not None: - alias = self._polymorphic_adapters.get(search, None) - if alias: - return alias.adapt_clause(element) + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) - def __replace_element(self, adapters): - def replace(elem): - if '_halt_adapt' in elem._annotations: - return elem - 
- for adapter in adapters: - e = adapter(elem) - if e is not None: - return e - return replace - - def __replace_orm_element(self, adapters): - def replace(elem): - if '_halt_adapt' in elem._annotations: - return elem - - if "_orm_adapt" in elem._annotations or "parententity" in elem._annotations: - for adapter in adapters: - e = adapter(elem) - if e is not None: - return e - return replace + def _adapt_col_list(self, cols): + return [ + self._adapt_clause( + expression._literal_as_label_reference(o), + True, True) + for o in cols + ] @_generative() def _adapt_all_clauses(self): - self._disable_orm_filtering = True - - def _adapt_col_list(self, cols): - return [ - self._adapt_clause(expression._literal_as_text(o), True, True) - for o in cols - ] - + self._orm_only_adapt = False + def _adapt_clause(self, clause, as_filter, orm_only): + """Adapt incoming clauses to transformations which + have been applied within this query.""" + adapters = [] + # do we adapt all expression elements or only those + # tagged as 'ORM' constructs ? + if not self._orm_only_adapt: + orm_only = False + if as_filter and self._filter_aliases: for fa in self._filter_aliases._visitor_iterator: - adapters.append(fa.replace) + adapters.append( + ( + orm_only, fa.replace + ) + ) if self._from_obj_alias: - adapters.append(self._from_obj_alias.replace) + # for the "from obj" alias, apply extra rule to the + # 'ORM only' check, if this query were generated from a + # subquery of itself, i.e. _from_selectable(), apply adaption + # to all SQL constructs. + adapters.append( + ( + orm_only if self._orm_only_from_obj_alias else False, + self._from_obj_alias.replace + ) + ) if self._polymorphic_adapters: - adapters.append(self.__adapt_polymorphic_element) + adapters.append( + ( + orm_only, self._adapt_polymorphic_element + ) + ) if not adapters: return clause - if getattr(self, '_disable_orm_filtering', not orm_only): - return visitors.replacement_traverse( - clause, - {'column_collections':False}, - self.__replace_element(adapters) - ) - else: - return visitors.replacement_traverse( - clause, - {'column_collections':False}, - self.__replace_orm_element(adapters) - ) + def replace(elem): + for _orm_only, adapter in adapters: + # if 'orm only', look for ORM annotations + # in the element before adapting. 
+ if not _orm_only or \ + '_orm_adapt' in elem._annotations or \ + "parententity" in elem._annotations: - def _entity_zero(self): + e = adapter(elem) + if e is not None: + return e + + return visitors.replacement_traverse( + clause, + {}, + replace + ) + + def _query_entity_zero(self): + """Return the first QueryEntity.""" return self._entities[0] def _mapper_zero(self): - return self._entity_zero().entity_zero + """return the Mapper associated with the first QueryEntity.""" + return self._entities[0].mapper - def _extension_zero(self): - ent = self._entity_zero() - return getattr(ent, 'extension', ent.mapper.extension) + def _entity_zero(self): + """Return the 'entity' (mapper or AliasedClass) associated + with the first QueryEntity, or alternatively the 'select from' + entity if specified.""" + + return self._select_from_entity \ + if self._select_from_entity is not None \ + else self._query_entity_zero().entity_zero @property def _mapper_entities(self): - # TODO: this is wrong, its hardcoded to "priamry entity" when - # for the case of __all_equivs() it should not be - # the name of this accessor is wrong too for ent in self._entities: - if hasattr(ent, 'primary_entity'): + if isinstance(ent, _MapperEntity): yield ent def _joinpoint_zero(self): - return self._joinpoint.get('_joinpoint_entity', self._entity_zero().entity_zero) + return self._joinpoint.get( + '_joinpoint_entity', + self._entity_zero() + ) - def _mapper_zero_or_none(self): - if not getattr(self._entities[0], 'primary_entity', False): - return None - return self._entities[0].mapper + def _bind_mapper(self): + ezero = self._entity_zero() + if ezero is not None: + insp = inspect(ezero) + if not insp.is_clause_element: + return insp.mapper - def _only_mapper_zero(self, rationale=None): - if len(self._entities) > 1: + return None + + def _only_full_mapper_zero(self, methname): + if self._entities != [self._primary_entity]: raise sa_exc.InvalidRequestError( - rationale or "This operation requires a Query against a single mapper." - ) - return self._mapper_zero() + "%s() can only be used against " + "a single mapped class." % methname) + return self._primary_entity.entity_zero def _only_entity_zero(self, rationale=None): if len(self._entities) > 1: raise sa_exc.InvalidRequestError( - rationale or "This operation requires a Query against a single mapper." - ) + rationale or + "This operation requires a Query " + "against a single mapper." 
+ ) return self._entity_zero() - def _generate_mapper_zero(self): - if not getattr(self._entities[0], 'primary_entity', False): - raise sa_exc.InvalidRequestError("No primary mapper set up for this Query.") - entity = self._entities[0]._clone() - self._entities = [entity] + self._entities[1:] - return entity - def __all_equivs(self): equivs = {} for ent in self._mapper_entities: @@ -289,18 +379,26 @@ class Query(object): return equivs def _get_condition(self): - self._order_by = self._distinct = False - return self._no_criterion_condition("get") - - def _no_criterion_condition(self, meth): + return self._no_criterion_condition( + "get", order_by=False, distinct=False) + + def _get_existing_condition(self): + self._no_criterion_assertion("get", order_by=False, distinct=False) + + def _no_criterion_assertion(self, meth, order_by=True, distinct=True): if not self._enable_assertions: return - if self._criterion is not None or self._statement is not None or self._from_obj or \ + if self._criterion is not None or \ + self._statement is not None or self._from_obj or \ self._limit is not None or self._offset is not None or \ - self._group_by or self._order_by or self._distinct: + self._group_by or (order_by and self._order_by) or \ + (distinct and self._distinct): raise sa_exc.InvalidRequestError( - "Query.%s() being called on a " - "Query with existing criterion. " % meth) + "Query.%s() being called on a " + "Query with existing criterion. " % meth) + + def _no_criterion_condition(self, meth, order_by=True, distinct=True): + self._no_criterion_assertion(meth, order_by, distinct) self._from_obj = () self._statement = self._criterion = None @@ -311,14 +409,14 @@ class Query(object): return if self._order_by: raise sa_exc.InvalidRequestError( - "Query.%s() being called on a " - "Query with existing criterion. " % meth) + "Query.%s() being called on a " + "Query with existing criterion. " % meth) self._no_criterion_condition(meth) def _no_statement_condition(self, meth): if not self._enable_assertions: return - if self._statement: + if self._statement is not None: raise sa_exc.InvalidRequestError( ("Query.%s() being called on a Query with an existing full " "statement - can't apply criterion.") % meth) @@ -328,30 +426,18 @@ class Query(object): return if self._limit is not None or self._offset is not None: raise sa_exc.InvalidRequestError( - "Query.%s() being called on a Query which already has LIMIT or OFFSET applied. " - "To modify the row-limited results of a Query, call from_self() first. " - "Otherwise, call %s() before limit() or offset() are applied." % (meth, meth) + "Query.%s() being called on a Query which already has LIMIT " + "or OFFSET applied. To modify the row-limited results of a " + " Query, call from_self() first. " + "Otherwise, call %s() before limit() or offset() " + "are applied." 
+ % (meth, meth) ) - def _no_select_modifiers(self, meth): - if not self._enable_assertions: - return - for attr, methname, notset in ( - ('_limit', 'limit()', None), - ('_offset', 'offset()', None), - ('_order_by', 'order_by()', False), - ('_group_by', 'group_by()', False), - ('_distinct', 'distinct()', False), - ): - if getattr(self, attr) is not notset: - raise sa_exc.InvalidRequestError( - "Can't call Query.%s() when %s has been called" % (meth, methname) - ) - - def _get_options(self, populate_existing=None, - version_check=None, - only_load_props=None, - refresh_state=None): + def _get_options(self, populate_existing=None, + version_check=None, + only_load_props=None, + refresh_state=None): if populate_existing: self._populate_existing = populate_existing if version_check: @@ -371,35 +457,154 @@ class Query(object): @property def statement(self): """The full SELECT statement represented by this Query. - + The statement by default will not have disambiguating labels applied to the construct unless with_labels(True) is called first. - + """ - return self._compile_context(labels=self._with_labels).\ - statement._annotate({'_halt_adapt': True}) + stmt = self._compile_context(labels=self._with_labels).\ + statement + if self._params: + stmt = stmt.params(self._params) - def subquery(self): - """return the full SELECT statement represented by this Query, - embedded within an Alias. + # TODO: there's no tests covering effects of + # the annotation not being there + return stmt._annotate({'no_replacement_traverse': True}) + + def subquery(self, name=None, with_labels=False, reduce_columns=False): + """return the full SELECT statement represented by + this :class:`.Query`, embedded within an :class:`.Alias`. Eager JOIN generation within the query is disabled. - The statement by default will not have disambiguating labels - applied to the construct unless with_labels(True) is called - first. + :param name: string name to be assigned as the alias; + this is passed through to :meth:`.FromClause.alias`. + If ``None``, a name will be deterministically generated + at compile time. + + :param with_labels: if True, :meth:`.with_labels` will be called + on the :class:`.Query` first to apply table-qualified labels + to all columns. + + :param reduce_columns: if True, :meth:`.Select.reduce_columns` will + be called on the resulting :func:`.select` construct, + to remove same-named columns where one also refers to the other + via foreign key or WHERE clause equivalence. + + .. versionchanged:: 0.8 the ``with_labels`` and ``reduce_columns`` + keyword arguments were added. """ - return self.enable_eagerloads(False).statement.alias() + q = self.enable_eagerloads(False) + if with_labels: + q = q.with_labels() + q = q.statement + if reduce_columns: + q = q.reduce_columns() + return q.alias(name=name) + + def cte(self, name=None, recursive=False): + r"""Return the full SELECT statement represented by this + :class:`.Query` represented as a common table expression (CTE). + + Parameters and usage are the same as those of the + :meth:`.SelectBase.cte` method; see that method for + further details. + + Here is the `PostgreSQL WITH + RECURSIVE example + `_. + Note that, in this example, the ``included_parts`` cte and the + ``incl_alias`` alias of it are Core selectables, which + means the columns are accessed via the ``.c.`` attribute. 
The + ``parts_alias`` object is an :func:`.orm.aliased` instance of the + ``Part`` entity, so column-mapped attributes are available + directly:: + + from sqlalchemy.orm import aliased + + class Part(Base): + __tablename__ = 'part' + part = Column(String, primary_key=True) + sub_part = Column(String, primary_key=True) + quantity = Column(Integer) + + included_parts = session.query( + Part.sub_part, + Part.part, + Part.quantity).\ + filter(Part.part=="our part").\ + cte(name="included_parts", recursive=True) + + incl_alias = aliased(included_parts, name="pr") + parts_alias = aliased(Part, name="p") + included_parts = included_parts.union_all( + session.query( + parts_alias.sub_part, + parts_alias.part, + parts_alias.quantity).\ + filter(parts_alias.part==incl_alias.c.sub_part) + ) + + q = session.query( + included_parts.c.sub_part, + func.sum(included_parts.c.quantity). + label('total_quantity') + ).\ + group_by(included_parts.c.sub_part) + + .. seealso:: + + :meth:`.HasCTE.cte` + + """ + return self.enable_eagerloads(False).\ + statement.cte(name=name, recursive=recursive) + + def label(self, name): + """Return the full SELECT statement represented by this + :class:`.Query`, converted + to a scalar subquery with a label of the given name. + + Analogous to :meth:`sqlalchemy.sql.expression.SelectBase.label`. + + .. versionadded:: 0.6.5 + + """ + + return self.enable_eagerloads(False).statement.label(name) + + def as_scalar(self): + """Return the full SELECT statement represented by this + :class:`.Query`, converted to a scalar subquery. + + Analogous to :meth:`sqlalchemy.sql.expression.SelectBase.as_scalar`. + + .. versionadded:: 0.6.5 + + """ + + return self.enable_eagerloads(False).statement.as_scalar() + + @property + def selectable(self): + """Return the :class:`.Select` object emitted by this :class:`.Query`. + + Used for :func:`.inspect` compatibility, this is equivalent to:: + + query.enable_eagerloads(False).with_labels().statement + + """ + return self.__clause_element__() def __clause_element__(self): return self.enable_eagerloads(False).with_labels().statement @_generative() def enable_eagerloads(self, value): - """Control whether or not eager joins and subqueries are + """Control whether or not eager joins and subqueries are rendered. When set to False, the returned Query will not render @@ -410,11 +615,19 @@ class Query(object): This is used primarily when nesting the Query's statement into a subquery or other - selectable. + selectable, or when using :meth:`.Query.yield_per`. """ self._enable_eagerloads = value + def _no_yield_per(self, message): + raise sa_exc.InvalidRequestError( + "The yield_per Query option is currently not " + "compatible with %s eager loading. Please " + "specify lazyload('*') or query.enable_eagerloads(False) in " + "order to " + "proceed with query.yield_per()." % message) + @_generative() def with_labels(self): """Apply column labels to the return value of Query.statement. @@ -428,129 +641,289 @@ class Query(object): When the `Query` actually issues SQL to load rows, it always uses column labeling. + .. note:: The :meth:`.Query.with_labels` method *only* applies + the output of :attr:`.Query.statement`, and *not* to any of + the result-row invoking systems of :class:`.Query` itself, e.g. + :meth:`.Query.first`, :meth:`.Query.all`, etc. 
To execute + a query using :meth:`.Query.with_labels`, invoke the + :attr:`.Query.statement` using :meth:`.Session.execute`:: + + result = session.execute(query.with_labels().statement) + + """ self._with_labels = True - + @_generative() def enable_assertions(self, value): """Control whether assertions are generated. - - When set to False, the returned Query will - not assert its state before certain operations, + + When set to False, the returned Query will + not assert its state before certain operations, including that LIMIT/OFFSET has not been applied when filter() is called, no criterion exists when get() is called, and no "from_statement()" exists when filter()/order_by()/group_by() etc. - is called. This more permissive mode is used by - custom Query subclasses to specify criterion or + is called. This more permissive mode is used by + custom Query subclasses to specify criterion or other modifiers outside of the usual usage patterns. - - Care should be taken to ensure that the usage + + Care should be taken to ensure that the usage pattern is even possible. A statement applied by from_statement() will override any criterion set by filter() or order_by(), for example. - + """ self._enable_assertions = value - + @property def whereclause(self): - """The WHERE criterion for this Query.""" + """A readonly attribute which returns the current WHERE criterion for + this Query. + + This returned value is a SQL expression construct, or ``None`` if no + criterion has been established. + + """ return self._criterion @_generative() def _with_current_path(self, path): - """indicate that this query applies to objects loaded within a certain path. + """indicate that this query applies to objects loaded + within a certain path. - Used by deferred loaders (see strategies.py) which transfer query - options from an originating query to a newly generated query intended - for the deferred load. + Used by deferred loaders (see strategies.py) which transfer + query options from an originating query to a newly generated + query intended for the deferred load. """ self._current_path = path @_generative(_no_clauseelement_condition) - def with_polymorphic(self, cls_or_mappers, selectable=None, discriminator=None): - """Load columns for descendant mappers of this Query's mapper. + def with_polymorphic(self, + cls_or_mappers, + selectable=None, + polymorphic_on=None): + """Load columns for inheriting classes. - Using this method will ensure that each descendant mapper's - tables are included in the FROM clause, and will allow filter() - criterion to be used against those tables. The resulting - instances will also have those columns already loaded so that - no "post fetch" of those columns will be required. + :meth:`.Query.with_polymorphic` applies transformations + to the "main" mapped class represented by this :class:`.Query`. + The "main" mapped class here means the :class:`.Query` + object's first argument is a full class, i.e. + ``session.query(SomeClass)``. These transformations allow additional + tables to be present in the FROM clause so that columns for a + joined-inheritance subclass are available in the query, both for the + purposes of load-time efficiency as well as the ability to use + these columns at query time. - :param cls_or_mappers: a single class or mapper, or list of class/mappers, - which inherit from this Query's mapper. Alternatively, it - may also be the string ``'*'``, in which case all descending - mappers will be added to the FROM clause. 
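
A minimal usage sketch of this method, assuming a hypothetical joined-inheritance hierarchy in which ``Engineer`` and ``Manager`` both extend ``Employee``::

    from sqlalchemy import or_

    # bring the engineer and manager tables into the FROM clause so
    # that subclass columns load up front and can be filtered on
    q = session.query(Employee).with_polymorphic([Engineer, Manager])
    q = q.filter(or_(Engineer.engineer_info == 'vista',
                     Manager.manager_data == 'dilbert'))
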
+ See the documentation section :ref:`with_polymorphic` for
+ details on how this method is used.

- :param selectable: a table or select() statement that will
- be used in place of the generated FROM clause. This argument
- is required if any of the desired mappers use concrete table
- inheritance, since SQLAlchemy currently cannot generate UNIONs
- among tables automatically. If used, the ``selectable``
- argument must represent the full set of tables and columns mapped
- by every desired mapper. Otherwise, the unaccounted mapped columns
- will result in their table being appended directly to the FROM
- clause which will usually lead to incorrect results.
-
- :param discriminator: a column to be used as the "discriminator"
- column for the given selectable. If not given, the polymorphic_on
- attribute of the mapper will be used, if any. This is useful
- for mappers that don't have polymorphic loading behavior by default,
- such as concrete table mappers.
+ .. versionchanged:: 0.8
+ A new and more flexible function
+ :func:`.orm.with_polymorphic` supersedes
+ :meth:`.Query.with_polymorphic`, as it can apply the equivalent
+ functionality to any set of columns or classes in the
+ :class:`.Query`, not just the "zero mapper". See that
+ function for a description of arguments.

 """
- entity = self._generate_mapper_zero()
- entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable, discriminator=discriminator)
+
+ if not self._primary_entity:
+ raise sa_exc.InvalidRequestError(
+ "No primary mapper set up for this Query.")
+ entity = self._entities[0]._clone()
+ self._entities = [entity] + self._entities[1:]
+ entity.set_with_polymorphic(self,
+ cls_or_mappers,
+ selectable=selectable,
+ polymorphic_on=polymorphic_on)

 @_generative()
 def yield_per(self, count):
- """Yield only ``count`` rows at a time.
+ r"""Yield only ``count`` rows at a time.

- WARNING: use this method with caution; if the same instance is present
- in more than one batch of rows, end-user changes to attributes will be
- overwritten.
+ The purpose of this method is, when fetching very large result sets
+ (> 10K rows), to batch results in sub-collections and yield them
+ out partially, so that the Python interpreter doesn't need to allocate
+ very large areas of memory, which is both time-consuming and leads
+ to excessive memory use. The performance from fetching hundreds of
+ thousands of rows can often double when a suitable yield-per setting
+ (e.g. approximately 1000) is used, even with DBAPIs that buffer
+ rows (as most do).

- In particular, it's usually impossible to use this setting with
- eagerly loaded collections (i.e. any lazy='joined' or 'subquery')
- since those collections will be cleared for a new load when
- encountered in a subsequent result batch. In the case of 'subquery'
- loading, the full result for all rows is fetched which generally
- defeats the purpose of :meth:`~sqlalchemy.orm.query.Query.yield_per`.
+ The :meth:`.Query.yield_per` method **is not compatible with most
+ eager loading schemes, including subqueryload and joinedload with
+ collections**. For this reason, it may be helpful to disable
+ eager loads, either unconditionally with
+ :meth:`.Query.enable_eagerloads`::
+
+ q = sess.query(Object).yield_per(100).enable_eagerloads(False)
+
+ Or more selectively using :func:`.lazyload`, such as with
+ an asterisk to specify the default loader scheme::
+
+ q = sess.query(Object).yield_per(100).\
+ options(lazyload('*'), joinedload(Object.some_related))
+
+ ..
warning::
+
+ Use this method with caution; if the same instance is
+ present in more than one batch of rows, end-user changes
+ to attributes will be overwritten.
+
+ In particular, it's usually impossible to use this setting
+ with eagerly loaded collections (i.e. any lazy='joined' or
+ 'subquery') since those collections will be cleared for a
+ new load when encountered in a subsequent result batch.
+ In the case of 'subquery' loading, the full result for all
+ rows is fetched which generally defeats the purpose of
+ :meth:`~sqlalchemy.orm.query.Query.yield_per`.
+
+ Also note that while
+ :meth:`~sqlalchemy.orm.query.Query.yield_per` will set the
+ ``stream_results`` execution option to True, currently
+ this is only understood by
+ :mod:`~sqlalchemy.dialects.postgresql.psycopg2`,
+ :mod:`~sqlalchemy.dialects.mysql.mysqldb` and
+ :mod:`~sqlalchemy.dialects.mysql.pymysql` dialects
+ which will stream results using server side cursors
+ instead of pre-buffering all rows for this query. Other
+ DBAPIs **pre-buffer all rows** before making them
+ available. The memory use of raw database rows is much less
+ than that of an ORM-mapped object, but should still be taken into
+ consideration when benchmarking.
+
+ .. seealso::
+
+ :meth:`.Query.enable_eagerloads`

- Also note that many DBAPIs do not "stream" results, pre-buffering
- all rows before making them available, including mysql-python and
- psycopg2. :meth:`~sqlalchemy.orm.query.Query.yield_per` will also
- set the ``stream_results`` execution
- option to ``True``, which currently is only understood by psycopg2
- and causes server side cursors to be used.
- """
 self._yield_per = count
- self._execution_options = self._execution_options.copy()
- self._execution_options['stream_results'] = True
-
- def get(self, ident):
- """Return an instance of the object based on the given identifier, or None if not found.
+ self._execution_options = self._execution_options.union(
+ {"stream_results": True,
+ "max_row_buffer": count})

- The `ident` argument is a scalar or tuple of primary key column values
- in the order of the table def's primary key columns.
+ def get(self, ident):
+ """Return an instance based on the given primary key identifier,
+ or ``None`` if not found.
+
+ E.g.::
+
+ my_user = session.query(User).get(5)
+
+ some_object = session.query(VersionedFoo).get((5, 10))
+
+ :meth:`~.Query.get` is special in that it provides direct
+ access to the identity map of the owning :class:`.Session`.
+ If the given primary key identifier is present
+ in the local identity map, the object is returned
+ directly from this collection and no SQL is emitted,
+ unless the object has been marked fully expired.
+ If not present,
+ a SELECT is performed in order to locate the object.
+
+ :meth:`~.Query.get` also will perform a check if
+ the object is present in the identity map and
+ marked as expired - a SELECT
+ is emitted to refresh the object as well as to
+ ensure that the row is still present.
+ If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
+ :meth:`~.Query.get` is only used to return a single
+ mapped instance, not multiple instances or
+ individual column constructs, and strictly
+ on a single primary key value. The originating
+ :class:`.Query` must be constructed in this way,
+ i.e. against a single mapped entity,
+ with no additional filtering criterion. Loading
+ options via :meth:`~.Query.options` may be applied
+ however, and will be used if the object is not
+ yet locally present.
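
The identity-map behavior described above can be observed directly; a brief sketch, assuming a mapped ``User`` class and an active ``session``::

    u1 = session.query(User).get(5)   # emits a SELECT, unless already loaded
    u2 = session.query(User).get(5)   # no SQL; returned from the identity map

    assert u1 is u2
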
+
+ A lazy-loading, many-to-one attribute configured
+ by :func:`.relationship`, using a simple
+ foreign-key-to-primary-key criterion, will also use an
+ operation equivalent to :meth:`~.Query.get` in order to retrieve
+ the target value from the local identity map
+ before querying the database. See :doc:`/orm/loading_relationships`
+ for further details on relationship loading.
+
+ :param ident: A scalar or tuple value representing
+ the primary key. For a composite primary key,
+ the order of identifiers corresponds in most cases
+ to that of the mapped :class:`.Table` object's
+ primary key columns. For a :func:`.mapper` that
+ was given the ``primary_key`` argument during
+ construction, the order of identifiers corresponds
+ to the elements present in this collection.
+
+ :return: The object instance, or ``None``.

 """
+ return self._get_impl(ident, loading.load_on_ident)

+ def _get_impl(self, ident, fallback_fn):
 # convert composite types to individual args
 if hasattr(ident, '__composite_values__'):
 ident = ident.__composite_values__()

- key = self._only_mapper_zero(
- "get() can only be used against a single mapped class."
- ).identity_key_from_primary_key(ident)
- return self._get(key, ident)
+ ident = util.to_list(ident)
+
+ mapper = self._only_full_mapper_zero("get")
+
+ if len(ident) != len(mapper.primary_key):
+ raise sa_exc.InvalidRequestError(
+ "Incorrect number of values in identifier to formulate "
+ "primary key for query.get(); primary key columns are %s" %
+ ','.join("'%s'" % c for c in mapper.primary_key))
+
+ key = mapper.identity_key_from_primary_key(ident)
+
+ if not self._populate_existing and \
+ not mapper.always_refresh and \
+ self._for_update_arg is None:
+
+ instance = loading.get_from_identity(
+ self.session, key, attributes.PASSIVE_OFF)
+ if instance is not None:
+ self._get_existing_condition()
+ # reject calls for id in identity map but class
+ # mismatch.
+ if not issubclass(instance.__class__, mapper.class_):
+ return None
+ return instance
+
+ return fallback_fn(self, key)

 @_generative()
 def correlate(self, *args):
- self._correlate = self._correlate.union(_orm_selectable(s) for s in args)
+ """Return a :class:`.Query` construct which will correlate the given
+ FROM clauses to that of an enclosing :class:`.Query` or
+ :func:`~.expression.select`.
+
+ The method here accepts mapped classes, :func:`.aliased` constructs,
+ and :func:`.mapper` constructs as arguments, which are resolved into
+ expression constructs; appropriate expression constructs may also
+ be passed directly.
+
+ The correlation arguments are ultimately passed to
+ :meth:`.Select.correlate` after coercion to expression constructs.
+
+ The correlation arguments take effect in such cases
+ as when :meth:`.Query.from_self` is used, or when
+ a subquery as returned by :meth:`.Query.subquery` is
+ embedded in another :func:`~.expression.select` construct.
+
+ """
+
+ for s in args:
+ if s is None:
+ self._correlate = self._correlate.union([None])
+ else:
+ self._correlate = self._correlate.union(
+ sql_util.surface_selectables(_interpret_as_from(s))
+ )

 @_generative()
 def autoflush(self, setting):
@@ -566,94 +939,305 @@ class Query(object):
 @_generative()
 def populate_existing(self):
- """Return a Query that will refresh all instances loaded.
+ """Return a :class:`.Query` that will expire and refresh all instances
+ as they are loaded, or reused from the current :class:`.Session`.

- This includes all entities accessed from the database, including
- secondary entities, eagerly-loaded collection items.
- - All changes present on entities which are already present in the - session will be reset and the entities will all be marked "clean". - - An alternative to populate_existing() is to expire the Session - fully using session.expire_all(). + :meth:`.populate_existing` does not improve behavior when + the ORM is used normally - the :class:`.Session` object's usual + behavior of maintaining a transaction and expiring all attributes + after rollback or commit handles object state automatically. + This method is not intended for general use. """ self._populate_existing = True - def with_parent(self, instance, property=None): - """Add a join criterion corresponding to a relationship to the given - parent instance. + @_generative() + def _with_invoke_all_eagers(self, value): + """Set the 'invoke all eagers' flag which causes joined- and + subquery loaders to traverse into already-loaded related objects + and collections. - instance - a persistent or detached instance which is related to class - represented by this query. - - property - string name of the property which relates this query's class to the - instance. if None, the method will attempt to find a suitable - property. - - Currently, this method only works with immediate parent relationships, - but in the future may be enhanced to work across a chain of parent - mappers. + Default is that of :attr:`.Query._invoke_all_eagers`. """ - from sqlalchemy.orm import properties - mapper = object_mapper(instance) + self._invoke_all_eagers = value + + def with_parent(self, instance, property=None): + """Add filtering criterion that relates the given instance + to a child object or collection, using its attribute state + as well as an established :func:`.relationship()` + configuration. + + The method uses the :func:`.with_parent` function to generate + the clause, the result of which is passed to :meth:`.Query.filter`. + + Parameters are the same as :func:`.with_parent`, with the exception + that the given property can be None, in which case a search is + performed against this :class:`.Query` object's target mapper. 
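
A short sketch of both calling forms, assuming a ``User`` class with an ``addresses`` relationship to ``Address`` and an active ``session``::

    user = session.query(User).get(5)

    # property located automatically against the Address mapper
    addresses = session.query(Address).with_parent(user).all()

    # property passed explicitly
    addresses = session.query(Address).\
        with_parent(user, User.addresses).all()
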
+ + """ + if property is None: + mapper_zero = self._mapper_zero() + + mapper = object_mapper(instance) + for prop in mapper.iterate_properties: - if isinstance(prop, properties.PropertyLoader) and prop.mapper is self._mapper_zero(): + if isinstance(prop, properties.RelationshipProperty) and \ + prop.mapper is mapper_zero: + property = prop break else: raise sa_exc.InvalidRequestError( - "Could not locate a property which relates instances " - "of class '%s' to instances of class '%s'" % - (self._mapper_zero().class_.__name__, instance.__class__.__name__) - ) - else: - prop = mapper.get_property(property, resolve_synonyms=True) - return self.filter(prop.compare(operators.eq, instance, value_is_parent=True)) + "Could not locate a property which relates instances " + "of class '%s' to instances of class '%s'" % + ( + self._mapper_zero().class_.__name__, + instance.__class__.__name__) + ) + + return self.filter(with_parent(instance, property)) @_generative() def add_entity(self, entity, alias=None): - """add a mapped entity to the list of result columns to be returned.""" + """add a mapped entity to the list of result columns + to be returned.""" if alias is not None: entity = aliased(entity, alias) self._entities = list(self._entities) m = _MapperEntity(self, entity) - self._setup_aliasizers([m]) + self._set_entity_selectables([m]) + + @_generative() + def with_session(self, session): + """Return a :class:`.Query` that will use the given :class:`.Session`. + + While the :class:`.Query` object is normally instantiated using the + :meth:`.Session.query` method, it is legal to build the :class:`.Query` + directly without necessarily using a :class:`.Session`. Such a + :class:`.Query` object, or any :class:`.Query` already associated + with a different :class:`.Session`, can produce a new :class:`.Query` + object associated with a target session using this method:: + + from sqlalchemy.orm import Query + + query = Query([MyClass]).filter(MyClass.id == 5) + + result = query.with_session(my_session).one() + + """ + + self.session = session def from_self(self, *entities): - """return a Query that selects from this Query's SELECT statement. + r"""return a Query that selects from this Query's + SELECT statement. - \*entities - optional list of entities which will replace - those being selected. + :meth:`.Query.from_self` essentially turns the SELECT statement + into a SELECT of itself. Given a query such as:: + + q = session.query(User).filter(User.name.like('e%')) + + Given the :meth:`.Query.from_self` version:: + + q = session.query(User).filter(User.name.like('e%')).from_self() + + This query renders as: + + .. sourcecode:: sql + + SELECT anon_1.user_id AS anon_1_user_id, + anon_1.user_name AS anon_1_user_name + FROM (SELECT "user".id AS user_id, "user".name AS user_name + FROM "user" + WHERE "user".name LIKE :name_1) AS anon_1 + + There are lots of cases where :meth:`.Query.from_self` may be useful. + A simple one is where above, we may want to apply a row LIMIT to + the set of user objects we query against, and then apply additional + joins against that row-limited set:: + + q = session.query(User).filter(User.name.like('e%')).\ + limit(5).from_self().\ + join(User.addresses).filter(Address.email.like('q%')) + + The above query joins to the ``Address`` entity but only against the + first five results of the ``User`` query: + + .. 
sourcecode:: sql
+
+ SELECT anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1
+ LIMIT :param_1) AS anon_1
+ JOIN address ON anon_1.user_id = address.user_id
+ WHERE address.email LIKE :email_1
+
+ **Automatic Aliasing**
+
+ Another key behavior of :meth:`.Query.from_self` is that it applies
+ **automatic aliasing** to the entities inside the subquery, when
+ they are referenced on the outside. Above, if we continue to
+ refer to the ``User`` entity without any additional aliasing applied
+ to it, those references will be in terms of the subquery::
+
+ q = session.query(User).filter(User.name.like('e%')).\
+ limit(5).from_self().\
+ join(User.addresses).filter(Address.email.like('q%')).\
+ order_by(User.name)
+
+ The ORDER BY against ``User.name`` is aliased to be in terms of the
+ inner subquery:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1
+ LIMIT :param_1) AS anon_1
+ JOIN address ON anon_1.user_id = address.user_id
+ WHERE address.email LIKE :email_1 ORDER BY anon_1.user_name
+
+ The automatic aliasing feature only works in a **limited** way,
+ for simple filters and orderings. More ambitious constructions
+ such as referring to the entity in joins should prefer to use
+ explicit subquery objects, typically making use of the
+ :meth:`.Query.subquery` method to produce an explicit subquery object.
+ Always test the structure of queries by viewing the SQL to ensure
+ a particular structure does what's expected!
+
+ **Changing the Entities**
+
+ :meth:`.Query.from_self` also includes the ability to modify what
+ columns are being queried. In our example, we want ``User.id``
+ to be queried by the inner query, so that we can join to the
+ ``Address`` entity on the outside, but we only wanted the outer
+ query to return the ``Address.email`` column::
+
+ q = session.query(User).filter(User.name.like('e%')).\
+ limit(5).from_self(Address.email).\
+ join(User.addresses).filter(Address.email.like('q%'))
+
+ yielding:
+
+ .. sourcecode:: sql
+
+ SELECT address.email AS address_email
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1
+ LIMIT :param_1) AS anon_1
+ JOIN address ON anon_1.user_id = address.user_id
+ WHERE address.email LIKE :email_1
+
+ **Looking out for Inner / Outer Columns**
+
+ Keep in mind that when referring to columns that originate from
+ inside the subquery, we need to ensure they are present in the
+ columns clause of the subquery itself; this is an ordinary aspect of
+ SQL. For example, if we wanted to load from a joined entity inside
+ the subquery using :func:`.contains_eager`, we need to add those
+ columns. Below illustrates a join of ``Address`` to ``User``,
+ then a subquery, and then we'd like :func:`.contains_eager` to access
+ the ``User`` columns::
+
+ q = session.query(Address).join(Address.user).\
+ filter(User.name.like('e%'))
+
+ q = q.add_entity(User).from_self().\
+ options(contains_eager(Address.user))
+
+ We use :meth:`.Query.add_entity` above **before** we call
+ :meth:`.Query.from_self` so that the ``User`` columns are present
+ in the inner subquery, so that they are available to the
+ :func:`.contains_eager` modifier we are using on the outside,
+ producing:
+
+ ..
sourcecode:: sql
+
+ SELECT anon_1.address_id AS anon_1_address_id,
+ anon_1.address_email AS anon_1_address_email,
+ anon_1.address_user_id AS anon_1_address_user_id,
+ anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (
+ SELECT address.id AS address_id,
+ address.email AS address_email,
+ address.user_id AS address_user_id,
+ "user".id AS user_id,
+ "user".name AS user_name
+ FROM address JOIN "user" ON "user".id = address.user_id
+ WHERE "user".name LIKE :name_1) AS anon_1
+
+ If we didn't call ``add_entity(User)``, but still asked
+ :func:`.contains_eager` to load the ``User`` entity, it would be
+ forced to add the table on the outside without the correct
+ join criteria - note the ``anon_1, "user"`` phrase at
+ the end:
+
+ .. sourcecode:: sql
+
+ -- incorrect query
+ SELECT anon_1.address_id AS anon_1_address_id,
+ anon_1.address_email AS anon_1_address_email,
+ anon_1.address_user_id AS anon_1_address_user_id,
+ "user".id AS user_id,
+ "user".name AS user_name
+ FROM (
+ SELECT address.id AS address_id,
+ address.email AS address_email,
+ address.user_id AS address_user_id
+ FROM address JOIN "user" ON "user".id = address.user_id
+ WHERE "user".name LIKE :name_1) AS anon_1, "user"
+
+ :param \*entities: optional list of entities which will replace
+ those being selected.

 """
 fromclause = self.with_labels().enable_eagerloads(False).\
- statement.correlate(None)
+ statement.correlate(None)
 q = self._from_selectable(fromclause)
+ q._enable_single_crit = False
+ q._select_from_entity = self._entity_zero()
 if entities:
 q._set_entities(entities)
 return q
-
+
+ @_generative()
+ def _set_enable_single_crit(self, val):
+ self._enable_single_crit = val
+
 @_generative()
 def _from_selectable(self, fromclause):
- for attr in ('_statement', '_criterion', '_order_by', '_group_by',
- '_limit', '_offset', '_joinpath', '_joinpoint',
- '_distinct'
+ for attr in (
+ '_statement', '_criterion',
+ '_order_by', '_group_by',
+ '_limit', '_offset',
+ '_joinpath', '_joinpoint',
+ '_distinct', '_having',
+ '_prefixes', '_suffixes'
 ):
 self.__dict__.pop(attr, None)
- self._set_select_from(fromclause)
+ self._set_select_from([fromclause], True)
+
+ # this enables clause adaptation for non-ORM
+ # expressions.
+ self._orm_only_from_obj_alias = False
+
 old_entities = self._entities
 self._entities = []
 for e in old_entities:
 e.adapt_to_selectable(self, self._from_obj[0])

 def values(self, *columns):
- """Return an iterator yielding result tuples corresponding to the given list of columns"""
+ """Return an iterator yielding result tuples corresponding
+ to the given list of columns"""

 if not columns:
 return iter(())
@@ -665,19 +1249,44 @@ class Query(object):
 _values = values

 def value(self, column):
- """Return a scalar result corresponding to the given column expression."""
+ """Return a scalar result corresponding to the given
+ column expression."""
 try:
- # Py3K
- #return self.values(column).__next__()[0]
- # Py2K
- return self.values(column).next()[0]
- # end Py2K
+ return next(self.values(column))[0]
 except StopIteration:
 return None

+ @_generative()
+ def with_entities(self, *entities):
+ """Return a new :class:`.Query` replacing the SELECT list with the
+ given entities.
+ + e.g.:: + + # Users, filtered on some arbitrary criterion + # and then ordered by related email address + q = session.query(User).\ + join(User.address).\ + filter(User.name.like('%ed%')).\ + order_by(Address.email) + + # given *only* User.id==5, Address.email, and 'q', what + # would the *next* User in the result be ? + subq = q.with_entities(Address.email).\ + order_by(None).\ + filter(User.id==5).\ + subquery() + q = q.join((subq, subq.c.email < Address.email)).\ + limit(1) + + .. versionadded:: 0.6.5 + + """ + self._set_entities(entities) + @_generative() def add_columns(self, *column): - """Add one or more column expressions to the list + """Add one or more column expressions to the list of result columns to be returned.""" self._entities = list(self._entities) @@ -686,18 +1295,30 @@ class Query(object): _ColumnEntity(self, c) # _ColumnEntity may add many entities if the # given arg is a FROM clause - self._setup_aliasizers(self._entities[l:]) + self._set_entity_selectables(self._entities[l:]) - @util.pending_deprecation("add_column() superceded by add_columns()") + @util.pending_deprecation("0.7", + ":meth:`.add_column` is superseded " + "by :meth:`.add_columns`", + False) def add_column(self, column): - """Add a column expression to the list of result columns - to be returned.""" - + """Add a column expression to the list of result columns to be + returned. + + Pending deprecation: :meth:`.add_column` will be superseded by + :meth:`.add_columns`. + + """ return self.add_columns(column) def options(self, *args): """Return a new Query object, applying the given list of - MapperOptions. + mapper options. + + Most supplied options regard changing how column- and + relationship-mapped attributes are loaded. See the sections + :ref:`deferred` and :doc:`/orm/loading_relationships` for reference + documentation. """ return self._options(False, *args) @@ -719,111 +1340,279 @@ class Query(object): for opt in opts: opt.process_query(self) - @_generative() - def with_hint(self, selectable, text, dialect_name=None): - """Add an indexing hint for the given entity or selectable to - this :class:`Query`. - - Functionality is passed straight through to - :meth:`~sqlalchemy.sql.expression.Select.with_hint`, - with the addition that ``selectable`` can be a - :class:`Table`, :class:`Alias`, or ORM entity / mapped class - /etc. + def with_transformation(self, fn): + """Return a new :class:`.Query` object transformed by + the given function. + + E.g.:: + + def filter_something(criterion): + def transform(q): + return q.filter(criterion) + return transform + + q = q.with_transformation(filter_something(x==5)) + + This allows ad-hoc recipes to be created for :class:`.Query` + objects. See the example at :ref:`hybrid_transformers`. + + .. versionadded:: 0.7.4 + """ - mapper, selectable, is_aliased_class = _entity_info(selectable) - + return fn(self) + + @_generative() + def with_hint(self, selectable, text, dialect_name='*'): + """Add an indexing or other executional context + hint for the given entity or selectable to + this :class:`.Query`. + + Functionality is passed straight through to + :meth:`~sqlalchemy.sql.expression.Select.with_hint`, + with the addition that ``selectable`` can be a + :class:`.Table`, :class:`.Alias`, or ORM entity / mapped class + /etc. + + .. 
seealso::
+
+ :meth:`.Query.with_statement_hint`
+
+ """
+ if selectable is not None:
+ selectable = inspect(selectable).selectable
+
 self._with_hints += ((selectable, text, dialect_name),)
-
+
+ def with_statement_hint(self, text, dialect_name='*'):
+ """add a statement hint to this :class:`.Query`.
+
+ This method is similar to :meth:`.Select.with_hint` except that
+ it does not require an individual table, and instead applies to the
+ statement as a whole.
+
+ This feature calls down into :meth:`.Select.with_statement_hint`.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.Query.with_hint`
+
+ """
+ return self.with_hint(None, text, dialect_name)
+
 @_generative()
 def execution_options(self, **kwargs):
 """ Set non-SQL options which take effect during execution.
-
- The options are the same as those accepted by
- :meth:`sqlalchemy.sql.expression.Executable.execution_options`.
-
+
+ The options are the same as those accepted by
+ :meth:`.Connection.execution_options`.
+
 Note that the ``stream_results`` execution option is enabled
 automatically if the :meth:`~sqlalchemy.orm.query.Query.yield_per()`
 method is used.

 """
- _execution_options = self._execution_options.copy()
- for key, value in kwargs.items():
- _execution_options[key] = value
- self._execution_options = _execution_options
+ self._execution_options = self._execution_options.union(kwargs)

 @_generative()
 def with_lockmode(self, mode):
- """Return a new Query object with the specified locking mode."""
+ """Return a new :class:`.Query` object with the specified "locking mode",
+ which essentially refers to the ``FOR UPDATE`` clause.

- self._lockmode = mode
+ .. deprecated:: 0.9.0 superseded by :meth:`.Query.with_for_update`.
+
+ :param mode: a string representing the desired locking mode.
+ Valid values are:
+
+ * ``None`` - translates to no lockmode
+
+ * ``'update'`` - translates to ``FOR UPDATE``
+ (standard SQL, supported by most dialects)
+
+ * ``'update_nowait'`` - translates to ``FOR UPDATE NOWAIT``
+ (supported by Oracle, PostgreSQL 8.1 upwards)
+
+ * ``'read'`` - translates to ``LOCK IN SHARE MODE`` (for MySQL),
+ and ``FOR SHARE`` (for PostgreSQL)
+
+ .. seealso::
+
+ :meth:`.Query.with_for_update` - improved API for
+ specifying the ``FOR UPDATE`` clause.
+
+ """
+ self._for_update_arg = LockmodeArg.parse_legacy_query(mode)
+
+ @_generative()
+ def with_for_update(self, read=False, nowait=False, of=None,
+ skip_locked=False, key_share=False):
+ """return a new :class:`.Query` with the specified options for the
+ ``FOR UPDATE`` clause.
+
+ The behavior of this method is identical to that of
+ :meth:`.SelectBase.with_for_update`. When called with no arguments,
+ the resulting ``SELECT`` statement will have a ``FOR UPDATE`` clause
+ appended. When additional arguments are specified, backend-specific
+ options such as ``FOR UPDATE NOWAIT`` or ``LOCK IN SHARE MODE``
+ can take effect.
+
+ E.g.::
+
+ q = sess.query(User).with_for_update(nowait=True, of=User)
+
+ The above query on a PostgreSQL backend will render like::
+
+ SELECT users.id AS users_id FROM users FOR UPDATE OF users NOWAIT
+
+ .. versionadded:: 0.9.0 :meth:`.Query.with_for_update` supersedes
+ the :meth:`.Query.with_lockmode` method.
+
+ .. seealso::
+
+ :meth:`.GenerativeSelect.with_for_update` - Core level method with
+ full argument and behavioral description.
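
As a further illustration, a sketch of the ``skip_locked`` flag in a simple work-queue pattern; the ``Task`` class and its columns are hypothetical, and the backend must support ``SKIP LOCKED`` (e.g. recent PostgreSQL)::

    # claim one unprocessed row, skipping rows already locked by
    # other workers, then mark it complete in the same transaction
    task = session.query(Task).\
        filter(Task.processed == False).\
        with_for_update(skip_locked=True).\
        first()

    if task is not None:
        task.processed = True
        session.commit()
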
+ + """ + self._for_update_arg = LockmodeArg(read=read, nowait=nowait, of=of, + skip_locked=skip_locked, + key_share=key_share) @_generative() def params(self, *args, **kwargs): - """add values for bind parameters which may have been specified in filter(). + r"""add values for bind parameters which may have been + specified in filter(). - parameters may be specified using \**kwargs, or optionally a single dictionary - as the first positional argument. The reason for both is that \**kwargs is - convenient, however some parameter dictionaries contain unicode keys in which case - \**kwargs cannot be used. + parameters may be specified using \**kwargs, or optionally a single + dictionary as the first positional argument. The reason for both is + that \**kwargs is convenient, however some parameter dictionaries + contain unicode keys in which case \**kwargs cannot be used. """ if len(args) == 1: kwargs.update(args[0]) elif len(args) > 0: - raise sa_exc.ArgumentError("params() takes zero or one positional argument, which is a dictionary.") + raise sa_exc.ArgumentError( + "params() takes zero or one positional argument, " + "which is a dictionary.") self._params = self._params.copy() self._params.update(kwargs) @_generative(_no_statement_condition, _no_limit_offset) - def filter(self, criterion): - """apply the given filtering criterion to the query and return the newly resulting ``Query`` + def filter(self, *criterion): + r"""apply the given filtering criterion to a copy + of this :class:`.Query`, using SQL expressions. - the criterion is any sql.ClauseElement applicable to the WHERE clause of a select. + e.g.:: + + session.query(MyClass).filter(MyClass.name == 'some name') + + Multiple criteria may be specified as comma separated; the effect + is that they will be joined together using the :func:`.and_` + function:: + + session.query(MyClass).\ + filter(MyClass.name == 'some name', MyClass.id > 5) + + The criterion is any SQL expression object applicable to the + WHERE clause of a select. String expressions are coerced + into SQL expression constructs via the :func:`.text` construct. + + .. seealso:: + + :meth:`.Query.filter_by` - filter on keyword expressions. """ - if isinstance(criterion, basestring): - criterion = sql.text(criterion) + for criterion in list(criterion): + criterion = expression._expression_literal_as_text(criterion) - if criterion is not None and not isinstance(criterion, sql.ClauseElement): - raise sa_exc.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string") + criterion = self._adapt_clause(criterion, True, True) - criterion = self._adapt_clause(criterion, True, True) - - if self._criterion is not None: - self._criterion = self._criterion & criterion - else: - self._criterion = criterion + if self._criterion is not None: + self._criterion = self._criterion & criterion + else: + self._criterion = criterion def filter_by(self, **kwargs): - """apply the given filtering criterion to the query and return the newly resulting ``Query``.""" + r"""apply the given filtering criterion to a copy + of this :class:`.Query`, using keyword expressions. 
- clauses = [_entity_descriptor(self._joinpoint_zero(), key)[0] == value - for key, value in kwargs.iteritems()] + e.g.:: + session.query(MyClass).filter_by(name = 'some name') + + Multiple criteria may be specified as comma separated; the effect + is that they will be joined together using the :func:`.and_` + function:: + + session.query(MyClass).\ + filter_by(name = 'some name', id = 5) + + The keyword expressions are extracted from the primary + entity of the query, or the last entity that was the + target of a call to :meth:`.Query.join`. + + .. seealso:: + + :meth:`.Query.filter` - filter on SQL expressions. + + """ + + clauses = [_entity_descriptor(self._joinpoint_zero(), key) == value + for key, value in kwargs.items()] return self.filter(sql.and_(*clauses)) @_generative(_no_statement_condition, _no_limit_offset) - @util.accepts_a_list_as_starargs(list_deprecation='deprecated') def order_by(self, *criterion): - """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``""" + """apply one or more ORDER BY criterion to the query and return + the newly resulting ``Query`` - if len(criterion) == 1 and criterion[0] is None: - self._order_by = None + All existing ORDER BY settings can be suppressed by + passing ``None`` - this will suppress any ORDER BY configured + on mappers as well. + + Alternatively, passing False will reset ORDER BY and additionally + re-allow default mapper.order_by to take place. Note mapper.order_by + is deprecated. + + """ + + if len(criterion) == 1: + if criterion[0] is False: + if '_order_by' in self.__dict__: + self._order_by = False + return + if criterion[0] is None: + self._order_by = None + return + + criterion = self._adapt_col_list(criterion) + + if self._order_by is False or self._order_by is None: + self._order_by = criterion else: - criterion = self._adapt_col_list(criterion) - - if self._order_by is False or self._order_by is None: - self._order_by = criterion - else: - self._order_by = self._order_by + criterion + self._order_by = self._order_by + criterion @_generative(_no_statement_condition, _no_limit_offset) - @util.accepts_a_list_as_starargs(list_deprecation='deprecated') def group_by(self, *criterion): - """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``""" + """apply one or more GROUP BY criterion to the query and return + the newly resulting :class:`.Query` + + All existing GROUP BY settings can be suppressed by + passing ``None`` - this will suppress any GROUP BY configured + on mappers as well. + + .. versionadded:: 1.1 GROUP BY can be cancelled by passing None, + in the same way as ORDER BY. + + """ + + if len(criterion) == 1: + if criterion[0] is None: + self._group_by = False + return criterion = list(chain(*[_orm_columns(c) for c in criterion])) - criterion = self._adapt_col_list(criterion) if self._group_by is False: @@ -833,13 +1622,29 @@ class Query(object): @_generative(_no_statement_condition, _no_limit_offset) def having(self, criterion): - """apply a HAVING criterion to the query and return the newly resulting ``Query``.""" + r"""apply a HAVING criterion to the query and return the + newly resulting :class:`.Query`. - if isinstance(criterion, basestring): - criterion = sql.text(criterion) + :meth:`~.Query.having` is used in conjunction with + :meth:`~.Query.group_by`. 
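
Before continuing with ``having()``, the ORDER BY cancellation documented above deserves a quick sketch, since it matters when a query is to be wrapped in a set operation (see ``union()`` below); assuming a mapped ``User`` class::

    q = session.query(User).order_by(User.name)

    # strip all ORDER BY criteria, including any mapper-configured
    # ordering, before embedding the query in a UNION
    q = q.order_by(None).union(
        session.query(User).filter(User.name == 'ed'))
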
- if criterion is not None and not isinstance(criterion, sql.ClauseElement):
- raise sa_exc.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string")
+ HAVING criterion makes it possible to use filters on aggregate
+ functions like COUNT, SUM, AVG, MAX, and MIN, e.g.::
+
+ q = session.query(User.id).\
+ join(User.addresses).\
+ group_by(User.id).\
+ having(func.count(Address.id) > 2)
+
+ """
+
+ criterion = expression._expression_literal_as_text(criterion)
+
+ if criterion is not None and \
+ not isinstance(criterion, sql.ClauseElement):
+ raise sa_exc.ArgumentError(
+ "having() argument must be of type "
+ "sqlalchemy.sql.ClauseElement or string")

 criterion = self._adapt_clause(criterion, True, True)
@@ -848,6 +1653,11 @@ class Query(object):
 else:
 self._having = criterion

+ def _set_op(self, expr_fn, *q):
+ return self._from_selectable(
+ expr_fn(*([self] + list(q)))
+ )._set_enable_single_crit(False)
+
 def union(self, *q):
 """Produce a UNION of this Query against one or more queries.
@@ -865,7 +1675,8 @@
 will nest on each ``union()``, and produces::

- SELECT * FROM (SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y) UNION SELECT * FROM Z)
+ SELECT * FROM (SELECT * FROM (SELECT * FROM X UNION
+ SELECT * FROM y) UNION SELECT * FROM Z)

 Whereas::
@@ -873,166 +1684,386 @@
 produces::

- SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y UNION SELECT * FROM Z)
+ SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y UNION
+ SELECT * FROM Z)
+
+ Note that many database backends do not allow ORDER BY to
+ be rendered on a query called within UNION, EXCEPT, etc.
+ To disable all ORDER BY clauses including those configured
+ on mappers, issue ``query.order_by(None)`` - the resulting
+ :class:`.Query` object will not render ORDER BY within
+ its SELECT statement.

 """
-
-
- return self._from_selectable(
- expression.union(*([self]+ list(q))))
+ return self._set_op(expression.union, *q)

 def union_all(self, *q):
 """Produce a UNION ALL of this Query against one or more queries.

- Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See that
- method for usage examples.
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.

 """
- return self._from_selectable(
- expression.union_all(*([self]+ list(q)))
- )
+ return self._set_op(expression.union_all, *q)

 def intersect(self, *q):
 """Produce an INTERSECT of this Query against one or more queries.

- Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See that
- method for usage examples.
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.

 """
- return self._from_selectable(
- expression.intersect(*([self]+ list(q)))
- )
+ return self._set_op(expression.intersect, *q)

 def intersect_all(self, *q):
 """Produce an INTERSECT ALL of this Query against one or more queries.

- Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See that
- method for usage examples.
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.

 """
- return self._from_selectable(
- expression.intersect_all(*([self]+ list(q)))
- )
+ return self._set_op(expression.intersect_all, *q)

 def except_(self, *q):
 """Produce an EXCEPT of this Query against one or more queries.

- Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See that
- method for usage examples.
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. """ - return self._from_selectable( - expression.except_(*([self]+ list(q))) - ) + return self._set_op(expression.except_, *q) def except_all(self, *q): """Produce an EXCEPT ALL of this Query against one or more queries. - Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See that - method for usage examples. + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. """ - return self._from_selectable( - expression.except_all(*([self]+ list(q))) - ) + return self._set_op(expression.except_all, *q) - @util.accepts_a_list_as_starargs(list_deprecation='deprecated') def join(self, *props, **kwargs): - """Create a join against this ``Query`` object's criterion - and apply generatively, returning the newly resulting ``Query``. + r"""Create a SQL JOIN against this :class:`.Query` object's criterion + and apply generatively, returning the newly resulting :class:`.Query`. - Each element in \*props may be: + **Simple Relationship Joins** - * a string property name, i.e. "rooms". This will join along the - relationship of the same name from this Query's "primary" mapper, if - one is present. + Consider a mapping between two classes ``User`` and ``Address``, + with a relationship ``User.addresses`` representing a collection + of ``Address`` objects associated with each ``User``. The most + common usage of :meth:`~.Query.join` is to create a JOIN along this + relationship, using the ``User.addresses`` attribute as an indicator + for how this should occur:: - * a class-mapped attribute, i.e. Houses.rooms. This will create a - join from "Houses" table to that of the "rooms" relationship. + q = session.query(User).join(User.addresses) - * a 2-tuple containing a target class or selectable, and an "ON" - clause. The ON clause can be the property name/ attribute like - above, or a SQL expression. + Where above, the call to :meth:`~.Query.join` along ``User.addresses`` + will result in SQL equivalent to:: - e.g.:: + SELECT user.* FROM user JOIN address ON user.id = address.user_id - # join along string attribute names - session.query(Company).join('employees') - session.query(Company).join('employees', 'tasks') + In the above example we refer to ``User.addresses`` as passed to + :meth:`~.Query.join` as the *on clause*, that is, it indicates + how the "ON" portion of the JOIN should be constructed. For a + single-entity query such as the one above (i.e. we start by selecting + only from ``User`` and nothing else), the relationship can also be + specified by its string name:: - # join the Person entity to an alias of itself, - # along the "friends" relationship - PAlias = aliased(Person) - session.query(Person).join((Palias, Person.friends)) + q = session.query(User).join("addresses") - # join from Houses to the "rooms" attribute on the - # "Colonials" subclass of Houses, then join to the - # "closets" relationship on Room - session.query(Houses).join(Colonials.rooms, Room.closets) + :meth:`~.Query.join` can also accommodate multiple + "on clause" arguments to produce a chain of joins, such as below + where a join across four related entities is constructed:: - # join from Company entities to the "employees" collection, - # using "people JOIN engineers" as the target. Then join - # to the "computers" collection on the Engineer entity. 
- session.query(Company).join((people.join(engineers), 'employees'), Engineer.computers) + q = session.query(User).join("orders", "items", "keywords") - # join from Articles to Keywords, using the "keywords" attribute. - # assume this is a many-to-many relationship. - session.query(Article).join(Article.keywords) + The above would be shorthand for three separate calls to + :meth:`~.Query.join`, each using an explicit attribute to indicate + the source entity:: - # same thing, but spelled out entirely explicitly - # including the association table. - session.query(Article).join( - (article_keywords, Articles.id==article_keywords.c.article_id), - (Keyword, Keyword.id==article_keywords.c.keyword_id) - ) + q = session.query(User).\ + join(User.orders).\ + join(Order.items).\ + join(Item.keywords) - \**kwargs include: + **Joins to a Target Entity or Selectable** - aliased - when joining, create anonymous aliases of each table. This is - used for self-referential joins or multiple joins to the same table. - Consider usage of the aliased(SomeClass) construct as a more explicit - approach to this. + A second form of :meth:`~.Query.join` allows any mapped entity + or core selectable construct as a target. In this usage, + :meth:`~.Query.join` will attempt + to create a JOIN along the natural foreign key relationship between + two entities:: - from_joinpoint - when joins are specified using string property names, - locate the property from the mapper found in the most recent previous - join() call, instead of from the root entity. + q = session.query(User).join(Address) + + The above calling form of :meth:`~.Query.join` will raise an error if + either there are no foreign keys between the two entities, or if + there are multiple foreign key linkages between them. In the + above calling form, :meth:`~.Query.join` is called upon to + create the "on clause" automatically for us. The target can + be any mapped entity or selectable, such as a :class:`.Table`:: + + q = session.query(User).join(addresses_table) + + **Joins to a Target with an ON Clause** + + The third calling form allows both the target entity as well + as the ON clause to be passed explicitly. Suppose for + example we wanted to join to ``Address`` twice, using + an alias the second time. We use :func:`~sqlalchemy.orm.aliased` + to create a distinct alias of ``Address``, and join + to it using the ``target, onclause`` form, so that the + alias can be specified explicitly as the target along with + the relationship to instruct how the ON clause should proceed:: + + a_alias = aliased(Address) + + q = session.query(User).\ + join(User.addresses).\ + join(a_alias, User.addresses).\ + filter(Address.email_address=='ed@foo.com').\ + filter(a_alias.email_address=='ed@bar.com') + + Where above, the generated SQL would be similar to:: + + SELECT user.* FROM user + JOIN address ON user.id = address.user_id + JOIN address AS address_1 ON user.id=address_1.user_id + WHERE address.email_address = :email_address_1 + AND address_1.email_address = :email_address_2 + + The two-argument calling form of :meth:`~.Query.join` + also allows us to construct arbitrary joins with SQL-oriented + "on clause" expressions, not relying upon configured relationships + at all. Any SQL expression can be passed as the ON clause + when using the two-argument form, which should refer to the target + entity in some way as well as an applicable source entity:: + + q = session.query(User).join(Address, User.id==Address.user_id) + + .. 
versionchanged:: 0.7
+ In SQLAlchemy 0.6 and earlier, the two argument form of
+ :meth:`~.Query.join` required the usage of a tuple:
+ ``query(User).join((Address, User.id==Address.user_id))``\ .
+ This calling form is accepted in 0.7 and further, though
+ is not necessary unless multiple join conditions are passed to
+ a single :meth:`~.Query.join` call, which itself is also not
+ generally necessary as it is now equivalent to multiple
+ calls (this wasn't always the case).
+
+ **Advanced Join Targeting and Adaptation**
+
+ There is a lot of flexibility in what the "target" can be when using
+ :meth:`~.Query.join`. As noted previously, it also accepts
+ :class:`.Table` constructs and other selectables such as
+ :func:`.alias` and :func:`.select` constructs, with either the one
+ or two-argument forms::
+
+ addresses_q = select([Address.user_id]).\
+ where(Address.email_address.endswith("@bar.com")).\
+ alias()
+
+ q = session.query(User).\
+ join(addresses_q, addresses_q.c.user_id==User.id)
+
+ :meth:`~.Query.join` also features the ability to *adapt* a
+ :func:`~sqlalchemy.orm.relationship`-driven ON clause to the target
+ selectable. Below we construct a JOIN from ``User`` to a subquery
+ against ``Address``, allowing the relationship denoted by
+ ``User.addresses`` to *adapt* itself to the altered target::
+
+ address_subq = session.query(Address).\
+ filter(Address.email_address == 'ed@foo.com').\
+ subquery()
+
+ q = session.query(User).join(address_subq, User.addresses)
+
+ Producing SQL similar to::
+
+ SELECT user.* FROM user
+ JOIN (
+ SELECT address.id AS id,
+ address.user_id AS user_id,
+ address.email_address AS email_address
+ FROM address
+ WHERE address.email_address = :email_address_1
+ ) AS anon_1 ON user.id = anon_1.user_id
+
+ The above form allows one to fall back onto an explicit ON
+ clause at any time::
+
+ q = session.query(User).\
+ join(address_subq, User.id==address_subq.c.user_id)
+
+ **Controlling what to Join From**
+
+ While :meth:`~.Query.join` exclusively deals with the "right"
+ side of the JOIN, we can also control the "left" side, in those
+ cases where it's needed, using :meth:`~.Query.select_from`.
+ Below we construct a query against ``Address`` but can still
+ make usage of ``User.addresses`` as our ON clause by instructing
+ the :class:`.Query` to select first from the ``User``
+ entity::
+
+ q = session.query(Address).select_from(User).\
+ join(User.addresses).\
+ filter(User.name == 'ed')
+
+ Which will produce SQL similar to::
+
+ SELECT address.* FROM user
+ JOIN address ON user.id=address.user_id
+ WHERE user.name = :name_1
+
+ **Constructing Aliases Anonymously**
+
+ :meth:`~.Query.join` can construct anonymous aliases
+ using the ``aliased=True`` flag. This feature is useful
+ when a query is being joined algorithmically, such as
+ when querying self-referentially to an arbitrary depth::
+
+ q = session.query(Node).\
+ join("children", "children", aliased=True)
+
+ When ``aliased=True`` is used, the actual "alias" construct
+ is not explicitly available.
To work with it, methods such as
+ :meth:`.Query.filter` will adapt the incoming entity to
+ the last join point::
+
+ q = session.query(Node).\
+ join("children", "children", aliased=True).\
+ filter(Node.name == 'grandchild 1')
+
+ When using automatic aliasing, the ``from_joinpoint=True``
+ argument can allow a multi-node join to be broken into
+ multiple calls to :meth:`~.Query.join`, so that
+ each path along the way can be further filtered::
+
+ q = session.query(Node).\
+ join("children", aliased=True).\
+ filter(Node.name == 'child 1').\
+ join("children", aliased=True, from_joinpoint=True).\
+ filter(Node.name == 'grandchild 1')
+
+ The filtering aliases above can then be reset back to the
+ original ``Node`` entity using :meth:`~.Query.reset_joinpoint`::
+
+ q = session.query(Node).\
+ join("children", "children", aliased=True).\
+ filter(Node.name == 'grandchild 1').\
+ reset_joinpoint().\
+ filter(Node.name == 'parent 1')
+
+ For an example of ``aliased=True``, see the distribution
+ example :ref:`examples_xmlpersistence` which illustrates
+ an XPath-like query system using algorithmic joins.
+
+ :param \*props: A collection of one or more join conditions,
+ each consisting of a relationship-bound attribute or string
+ relationship name representing an "on clause", or a single
+ target entity, or a tuple in the form of ``(target, onclause)``.
+ A special two-argument calling form ``target, onclause``
+ is also accepted.
+ :param aliased=False: If True, indicate that the JOIN target should be
+ anonymously aliased. Subsequent calls to :meth:`~.Query.filter`
+ and similar will adapt the incoming criterion to the target
+ alias, until :meth:`~.Query.reset_joinpoint` is called.
+ :param isouter=False: If True, the join used will be a left outer join,
+ just as if the :meth:`.Query.outerjoin` method were called. This
+ flag is here to maintain consistency with the same flag as accepted
+ by :meth:`.FromClause.join` and other Core constructs.
+
+ .. versionadded:: 1.0.0
+
+ :param full=False: render FULL OUTER JOIN; implies ``isouter``.
+
+ .. versionadded:: 1.1
+
+ :param from_joinpoint=False: When using ``aliased=True``, a setting
+ of True here will cause the join to be from the most recent
+ joined target, rather than starting back from the original
+ FROM clauses of the query.
+
+ .. seealso::
+
+ :ref:`ormtutorial_joins` in the ORM tutorial.
+
+ :ref:`inheritance_toplevel` for details on how
+ :meth:`~.Query.join` is used for inheritance relationships.
+
+ :func:`.orm.join` - a standalone ORM-level join function,
+ used internally by :meth:`.Query.join`, which in previous
+ SQLAlchemy versions was the primary ORM-level joining interface.
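
The ``isouter`` and ``full`` flags documented above lend themselves to a short sketch, assuming the hypothetical ``User`` / ``Address`` mapping used throughout::

    # LEFT OUTER JOIN; equivalent to Query.outerjoin()
    q = session.query(User).join(User.addresses, isouter=True)

    # FULL OUTER JOIN (1.1 and later); implies isouter
    q = session.query(User).join(User.addresses, full=True)
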
""" - aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) + aliased, from_joinpoint, isouter, full = kwargs.pop('aliased', False),\ + kwargs.pop('from_joinpoint', False),\ + kwargs.pop('isouter', False),\ + kwargs.pop('full', False) if kwargs: - raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys())) - return self._join(props, - outerjoin=False, create_aliases=aliased, - from_joinpoint=from_joinpoint) + raise TypeError("unknown arguments: %s" % + ', '.join(sorted(kwargs))) + return self._join(props, + outerjoin=isouter, full=full, + create_aliases=aliased, + from_joinpoint=from_joinpoint) - @util.accepts_a_list_as_starargs(list_deprecation='deprecated') def outerjoin(self, *props, **kwargs): """Create a left outer join against this ``Query`` object's criterion - and apply generatively, retunring the newly resulting ``Query``. + and apply generatively, returning the newly resulting ``Query``. Usage is the same as the ``join()`` method. """ - aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) + aliased, from_joinpoint, full = kwargs.pop('aliased', False), \ + kwargs.pop('from_joinpoint', False), \ + kwargs.pop('full', False) if kwargs: - raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys())) - return self._join(props, - outerjoin=True, create_aliases=aliased, - from_joinpoint=from_joinpoint) + raise TypeError("unknown arguments: %s" % + ', '.join(sorted(kwargs))) + return self._join(props, + outerjoin=True, full=full, create_aliases=aliased, + from_joinpoint=from_joinpoint) + + def _update_joinpoint(self, jp): + self._joinpoint = jp + # copy backwards to the root of the _joinpath + # dict, so that no existing dict in the path is mutated + while 'prev' in jp: + f, prev = jp['prev'] + prev = prev.copy() + prev[f] = jp + jp['prev'] = (f, prev) + jp = prev + self._joinpath = jp @_generative(_no_statement_condition, _no_limit_offset) - def _join(self, keys, outerjoin, create_aliases, from_joinpoint): - """consumes arguments from join() or outerjoin(), places them into a consistent - format with which to form the actual JOIN constructs. - + def _join(self, keys, outerjoin, full, create_aliases, from_joinpoint): + """consumes arguments from join() or outerjoin(), places them into a + consistent format with which to form the actual JOIN constructs. + """ - self._polymorphic_adapters = self._polymorphic_adapters.copy() if not from_joinpoint: self._reset_joinpoint() - for arg1 in util.to_list(keys): + if len(keys) == 2 and \ + isinstance(keys[0], (expression.FromClause, + type, AliasedClass)) and \ + isinstance(keys[1], (str, expression.ClauseElement, + interfaces.PropComparator)): + # detect 2-arg form of join and + # convert to a tuple. + keys = (keys,) + + keylist = util.to_list(keys) + for idx, arg1 in enumerate(keylist): if isinstance(arg1, tuple): + # "tuple" form of join, multiple + # tuples are accepted as well. The simpler + # "2-arg" form is preferred. May deprecate + # the "tuple" usage. arg1, arg2 = arg1 else: arg2 = None @@ -1041,81 +2072,177 @@ class Query(object): # is a little bit of legacy behavior still at work here # which means they might be in either order. may possibly # lock this down to (right_entity, onclause) in 0.6. 
- if isinstance(arg1, (interfaces.PropComparator, basestring)): + if isinstance( + arg1, (interfaces.PropComparator, util.string_types)): right_entity, onclause = arg2, arg1 else: right_entity, onclause = arg1, arg2 left_entity = prop = None - - if isinstance(onclause, basestring): + + if isinstance(onclause, interfaces.PropComparator): + of_type = getattr(onclause, '_of_type', None) + else: + of_type = None + + if isinstance(onclause, util.string_types): left_entity = self._joinpoint_zero() - descriptor, prop = _entity_descriptor(left_entity, onclause) + descriptor = _entity_descriptor(left_entity, onclause) onclause = descriptor - + # check for q.join(Class.propname, from_joinpoint=True) # and Class is that of the current joinpoint - elif from_joinpoint and isinstance(onclause, interfaces.PropComparator): - left_entity = onclause.parententity - + elif from_joinpoint and \ + isinstance(onclause, interfaces.PropComparator): + left_entity = onclause._parententity + + info = inspect(self._joinpoint_zero()) left_mapper, left_selectable, left_is_aliased = \ - _entity_info(self._joinpoint_zero()) + getattr(info, 'mapper', None), \ + info.selectable, \ + getattr(info, 'is_aliased_class', None) + if left_mapper is left_entity: left_entity = self._joinpoint_zero() - descriptor, prop = _entity_descriptor(left_entity, onclause.key) + descriptor = _entity_descriptor(left_entity, + onclause.key) onclause = descriptor if isinstance(onclause, interfaces.PropComparator): if right_entity is None: - right_entity = onclause.property.mapper - of_type = getattr(onclause, '_of_type', None) if of_type: right_entity = of_type else: right_entity = onclause.property.mapper - - left_entity = onclause.parententity - + + left_entity = onclause._parententity + prop = onclause.property - if not isinstance(onclause, attributes.QueryableAttribute): + if not isinstance(onclause, attributes.QueryableAttribute): onclause = prop if not create_aliases: # check for this path already present. # don't render in that case. - if (left_entity, right_entity, prop.key) in self._joinpoint: - self._joinpoint = self._joinpoint[(left_entity, right_entity, prop.key)] + edge = (left_entity, right_entity, prop.key) + if edge in self._joinpoint: + # The child's prev reference might be stale -- + # it could point to a parent older than the + # current joinpoint. If this is the case, + # then we need to update it and then fix the + # tree's spine with _update_joinpoint. Copy + # and then mutate the child, which might be + # shared by a different query object. 
+ jp = self._joinpoint[edge].copy() + jp['prev'] = (edge, self._joinpoint) + self._update_joinpoint(jp) + + if idx == len(keylist) - 1: + util.warn( + "Pathed join target %s has already " + "been joined to; skipping" % prop) continue elif onclause is not None and right_entity is None: # TODO: no coverage here raise NotImplementedError("query.join(a==b) not supported.") - - self._join_left_to_right( - left_entity, - right_entity, onclause, - outerjoin, create_aliases, prop) - def _join_left_to_right(self, left, right, onclause, outerjoin, create_aliases, prop): + self._join_left_to_right( + left_entity, + right_entity, onclause, + outerjoin, full, create_aliases, prop) + + def _join_left_to_right(self, left, right, + onclause, outerjoin, full, create_aliases, prop): """append a JOIN to the query's from clause.""" - + + self._polymorphic_adapters = self._polymorphic_adapters.copy() + if left is None: - left = self._joinpoint_zero() + if self._from_obj: + left = self._from_obj[0] + elif self._entities: + left = self._entities[0].entity_zero_or_selectable + + if left is None: + if self._entities: + problem = "Don't know how to join from %s" % self._entities[0] + else: + problem = "No entities to join from" + + raise sa_exc.InvalidRequestError( + "%s; please use " + "select_from() to establish the left " + "entity/selectable of this join" % problem) if left is right and \ not create_aliases: raise sa_exc.InvalidRequestError( - "Can't construct a join from %s to %s, they are the same entity" % - (left, right)) - - left_mapper, left_selectable, left_is_aliased = _entity_info(left) - right_mapper, right_selectable, is_aliased_class = _entity_info(right) + "Can't construct a join from %s to %s, they " + "are the same entity" % + (left, right)) - if right_mapper and prop and not right_mapper.common_parent(prop.mapper): + l_info = inspect(left) + r_info = inspect(right) + + overlap = False + if not create_aliases: + right_mapper = getattr(r_info, "mapper", None) + # if the target is a joined inheritance mapping, + # be more liberal about auto-aliasing. 
+ if right_mapper and ( + right_mapper.with_polymorphic or + isinstance(right_mapper.mapped_table, expression.Join) + ): + for from_obj in self._from_obj or [l_info.selectable]: + if sql_util.selectables_overlap( + l_info.selectable, from_obj) and \ + sql_util.selectables_overlap( + from_obj, r_info.selectable): + overlap = True + break + + if (overlap or not create_aliases) and \ + l_info.selectable is r_info.selectable: raise sa_exc.InvalidRequestError( - "Join target %s does not correspond to " - "the right side of join condition %s" % (right, onclause) + "Can't join table/selectable '%s' to itself" % + l_info.selectable) + + right, onclause = self._prepare_right_side( + r_info, right, onclause, + create_aliases, + prop, overlap) + + # if joining on a MapperProperty path, + # track the path to prevent redundant joins + if not create_aliases and prop: + self._update_joinpoint({ + '_joinpoint_entity': right, + 'prev': ((left, right, prop.key), self._joinpoint) + }) + else: + self._joinpoint = {'_joinpoint_entity': right} + + self._join_to_left(l_info, left, right, onclause, outerjoin, full) + + def _prepare_right_side(self, r_info, right, onclause, create_aliases, + prop, overlap): + info = r_info + + right_mapper, right_selectable, right_is_aliased = \ + getattr(info, 'mapper', None), \ + info.selectable, \ + getattr(info, 'is_aliased_class', False) + + if right_mapper: + self._join_entities += (info, ) + + if right_mapper and prop and \ + not right_mapper.common_parent(prop.mapper): + raise sa_exc.InvalidRequestError( + "Join target %s does not correspond to " + "the right side of join condition %s" % (right, onclause) ) if not right_mapper and prop: @@ -1124,99 +2251,89 @@ class Query(object): need_adapter = False if right_mapper and right is right_selectable: - if not right_selectable.is_derived_from(right_mapper.mapped_table): + if not right_selectable.is_derived_from( + right_mapper.mapped_table): raise sa_exc.InvalidRequestError( "Selectable '%s' is not derived from '%s'" % - (right_selectable.description, right_mapper.mapped_table.description)) + (right_selectable.description, + right_mapper.mapped_table.description)) - if not isinstance(right_selectable, expression.Alias): + if isinstance(right_selectable, expression.SelectBase): + # TODO: this isn't even covered now! 
right_selectable = right_selectable.alias()
+                need_adapter = True

            right = aliased(right_mapper, right_selectable)
-            need_adapter = True

        aliased_entity = right_mapper and \
-            not is_aliased_class and \
-            (
-                right_mapper.with_polymorphic or
-                isinstance(right_mapper.mapped_table, expression.Join)
-            )
+            not right_is_aliased and \
+            (
+                right_mapper.with_polymorphic and isinstance(
+                    right_mapper._with_polymorphic_selectable,
+                    expression.Alias)
+                or
+                overlap  # test for overlap:
+                # orm/inheritance/relationships.py
+                # SelfReferentialM2MTest
+            )

        if not need_adapter and (create_aliases or aliased_entity):
-            right = aliased(right)
+            right = aliased(right, flat=True)
            need_adapter = True

-        # if joining on a MapperProperty path,
-        # track the path to prevent redundant joins
-        if not create_aliases and prop:
-
-            self._joinpoint = jp = {
-                '_joinpoint_entity':right,
-                'prev':((left, right, prop.key), self._joinpoint)
-            }
-
-            # copy backwards to the root of the _joinpath
-            # dict, so that no existing dict in the path is mutated
-            while 'prev' in jp:
-                f, prev = jp['prev']
-                prev = prev.copy()
-                prev[f] = jp
-                jp['prev'] = (f, prev)
-                jp = prev
-
-            self._joinpath = jp
-
-        else:
-            self._joinpoint = {
-                '_joinpoint_entity':right
-            }
-
        # if an alias() of the right side was generated here,
        # apply an adapter to all subsequent filter() calls
        # until reset_joinpoint() is called.
        if need_adapter:
-            self._filter_aliases = ORMAdapter(right,
-                equivalents=right_mapper._equivalent_columns, chain_to=self._filter_aliases)
+            self._filter_aliases = ORMAdapter(
+                right,
+                equivalents=right_mapper and
+                right_mapper._equivalent_columns or {},
+                chain_to=self._filter_aliases)

-        # if the onclause is a ClauseElement, adapt it with any
+        # if the onclause is a ClauseElement, adapt it with any
        # adapters that are in place right now
        if isinstance(onclause, expression.ClauseElement):
            onclause = self._adapt_clause(onclause, True, True)
-
+
        # if an alias() on the right side was generated,
        # which is intended to wrap the right side in a subquery,
        # ensure that columns retrieved from this target in the result
        # set are also adapted.
-        if aliased_entity:
-            self.__mapper_loads_polymorphically_with(
-                right_mapper,
-                ORMAdapter(
-                    right,
-                    equivalents=right_mapper._equivalent_columns
-                )
-            )
-
-        join_to_left = not is_aliased_class and not left_is_aliased
+        if aliased_entity and not create_aliases:
+            self._mapper_loads_polymorphically_with(
+                right_mapper,
+                ORMAdapter(
+                    right,
+                    equivalents=right_mapper._equivalent_columns
+                )
+            )
+
+        return right, onclause
+
+    def _join_to_left(self, l_info, left, right, onclause, outerjoin, full):
+        info = l_info
+        left_mapper = getattr(info, 'mapper', None)
+        left_selectable = info.selectable

        if self._from_obj:
            replace_clause_index, clause = sql_util.find_join_source(
-                self._from_obj,
-                left_selectable)
+                self._from_obj,
+                left_selectable)
            if clause is not None:
-                # the entire query's FROM clause is an alias of itself (i.e. from_self(), similar).
-                # if the left clause is that one, ensure it aliases to the left side.
-                if self._from_obj_alias and clause is self._from_obj[0]:
-                    join_to_left = True
-
-                clause = orm_join(clause,
-                            right,
-                            onclause, isouter=outerjoin,
-                            join_to_left=join_to_left)
+                try:
+                    clause = orm_join(clause,
+                                      right,
+                                      onclause, isouter=outerjoin, full=full)
+                except sa_exc.ArgumentError as ae:
+                    raise sa_exc.InvalidRequestError(
+                        "Could not find a FROM clause to join from.
" + "Tried joining to %s, but got: %s" % (right, ae)) self._from_obj = \ - self._from_obj[:replace_clause_index] + \ - (clause, ) + \ - self._from_obj[replace_clause_index + 1:] + self._from_obj[:replace_clause_index] + \ + (clause, ) + \ + self._from_obj[replace_clause_index + 1:] return if left_mapper: @@ -1227,12 +2344,16 @@ class Query(object): else: clause = left else: - clause = None + clause = left_selectable - if clause is None: - raise sa_exc.InvalidRequestError("Could not find a FROM clause to join from") - - clause = orm_join(clause, right, onclause, isouter=outerjoin, join_to_left=join_to_left) + assert clause is not None + try: + clause = orm_join( + clause, right, onclause, isouter=outerjoin, full=full) + except sa_exc.ArgumentError as ae: + raise sa_exc.InvalidRequestError( + "Could not find a FROM clause to join from. " + "Tried joining to %s, but got: %s" % (right, ae)) self._from_obj = self._from_obj + (clause,) def _reset_joinpoint(self): @@ -1241,51 +2362,182 @@ class Query(object): @_generative(_no_statement_condition) def reset_joinpoint(self): - """return a new Query reset the 'joinpoint' of this Query reset - back to the starting mapper. Subsequent generative calls will - be constructed from the new joinpoint. + """Return a new :class:`.Query`, where the "join point" has + been reset back to the base FROM entities of the query. - Note that each call to join() or outerjoin() also starts from - the root. + This method is usually used in conjunction with the + ``aliased=True`` feature of the :meth:`~.Query.join` + method. See the example in :meth:`~.Query.join` for how + this is used. """ self._reset_joinpoint() @_generative(_no_clauseelement_condition) def select_from(self, *from_obj): - """Set the `from_obj` parameter of the query and return the newly - resulting ``Query``. This replaces the table which this Query selects - from with the given table. - - ``select_from()`` also accepts class arguments. Though usually not necessary, - can ensure that the full selectable of the given mapper is applied, e.g. - for joined-table mappers. + r"""Set the FROM clause of this :class:`.Query` explicitly. + + :meth:`.Query.select_from` is often used in conjunction with + :meth:`.Query.join` in order to control which entity is selected + from on the "left" side of the join. + + The entity or selectable object here effectively replaces the + "left edge" of any calls to :meth:`~.Query.join`, when no + joinpoint is otherwise established - usually, the default "join + point" is the leftmost entity in the :class:`~.Query` object's + list of entities to be selected. + + A typical example:: + + q = session.query(Address).select_from(User).\ + join(User.addresses).\ + filter(User.name == 'ed') + + Which produces SQL equivalent to:: + + SELECT address.* FROM user + JOIN address ON user.id=address.user_id + WHERE user.name = :name_1 + + :param \*from_obj: collection of one or more entities to apply + to the FROM clause. Entities can be mapped classes, + :class:`.AliasedClass` objects, :class:`.Mapper` objects + as well as core :class:`.FromClause` elements like subqueries. + + .. versionchanged:: 0.9 + This method no longer applies the given FROM object + to be the selectable from which matching entities + select from; the :meth:`.select_entity_from` method + now accomplishes this. See that method for a description + of this behavior. + + .. 
seealso::
+
+            :meth:`~.Query.join`
+
+            :meth:`.Query.select_entity_from`
+
        """
-
-        obj = []
-        for fo in from_obj:
-            if _is_mapped_class(fo):
-                mapper, selectable, is_aliased_class = _entity_info(fo)
-                obj.append(selectable)
-            elif not isinstance(fo, expression.FromClause):
-                raise sa_exc.ArgumentError("select_from() accepts FromClause objects only.")
-            else:
-                obj.append(fo)
-
-        self._set_select_from(*obj)
+
+        self._set_select_from(from_obj, False)
+
+    @_generative(_no_clauseelement_condition)
+    def select_entity_from(self, from_obj):
+        r"""Set the FROM clause of this :class:`.Query` to a
+        core selectable, applying it as a replacement FROM clause
+        for corresponding mapped entities.
+
+        The :meth:`.Query.select_entity_from` method supplies an alternative
+        approach to the use case of applying an :func:`.aliased` construct
+        explicitly throughout a query.  Instead of referring to the
+        :func:`.aliased` construct explicitly,
+        :meth:`.Query.select_entity_from` automatically *adapts* all
+        occurrences of the entity to the target selectable.
+
+        Given a case for :func:`.aliased` such as selecting ``User``
+        objects from a SELECT statement::
+
+            select_stmt = select([User]).where(User.id == 7)
+            user_alias = aliased(User, select_stmt)
+
+            q = session.query(user_alias).\
+                filter(user_alias.name == 'ed')
+
+        Above, we apply the ``user_alias`` object explicitly throughout the
+        query.  When it's not feasible for ``user_alias`` to be referenced
+        explicitly in many places, :meth:`.Query.select_entity_from` may be
+        used at the start of the query to adapt the existing ``User`` entity::
+
+            q = session.query(User).\
+                select_entity_from(select_stmt).\
+                filter(User.name == 'ed')
+
+        Above, the generated SQL will show that the ``User`` entity is
+        adapted to our statement, even in the case of the WHERE clause:
+
+        .. sourcecode:: sql
+
+            SELECT anon_1.id AS anon_1_id, anon_1.name AS anon_1_name
+            FROM (SELECT "user".id AS id, "user".name AS name
+            FROM "user"
+            WHERE "user".id = :id_1) AS anon_1
+            WHERE anon_1.name = :name_1
+
+        The :meth:`.Query.select_entity_from` method is similar to the
+        :meth:`.Query.select_from` method, in that it sets the FROM clause
+        of the query.  The difference is that it additionally applies
+        adaptation to the other parts of the query that refer to the
+        primary entity.  If above we had used :meth:`.Query.select_from`
+        instead, the SQL generated would have been:
+
+        .. sourcecode:: sql
+
+            -- uses plain select_from(), not select_entity_from()
+            SELECT "user".id AS user_id, "user".name AS user_name
+            FROM "user", (SELECT "user".id AS id, "user".name AS name
+            FROM "user"
+            WHERE "user".id = :id_1) AS anon_1
+            WHERE "user".name = :name_1
+
+        To supply textual SQL to the :meth:`.Query.select_entity_from` method,
+        we can make use of the :func:`.text` construct.  However, the
+        :func:`.text` construct needs to be aligned with the columns of our
+        entity, which is achieved by making use of the
+        :meth:`.TextClause.columns` method::
+
+            text_stmt = text("select id, name from user").columns(
+                User.id, User.name)
+            q = session.query(User).select_entity_from(text_stmt)
+
+        :meth:`.Query.select_entity_from` itself accepts an :func:`.aliased`
+        object, so that the special options of :func:`.aliased` such as
+        :paramref:`.aliased.adapt_on_names` may be used within the
+        scope of the :meth:`.Query.select_entity_from` method's adaptation
+        services.  Suppose
+        a view ``user_view`` also returns rows from ``user``.
If + we reflect this view into a :class:`.Table`, this view has no + relationship to the :class:`.Table` to which we are mapped, however + we can use name matching to select from it:: + + user_view = Table('user_view', metadata, + autoload_with=engine) + user_view_alias = aliased( + User, user_view, adapt_on_names=True) + q = session.query(User).\ + select_entity_from(user_view_alias).\ + order_by(User.name) + + .. versionchanged:: 1.1.7 The :meth:`.Query.select_entity_from` + method now accepts an :func:`.aliased` object as an alternative + to a :class:`.FromClause` object. + + :param from_obj: a :class:`.FromClause` object that will replace + the FROM clause of this :class:`.Query`. It also may be an instance + of :func:`.aliased`. + + + + .. seealso:: + + :meth:`.Query.select_from` + + """ + + self._set_select_from([from_obj], True) def __getitem__(self, item): if isinstance(item, slice): start, stop, step = util.decode_slice(item) - if isinstance(stop, int) and isinstance(start, int) and stop - start <= 0: + if isinstance(stop, int) and \ + isinstance(start, int) and \ + stop - start <= 0: return [] # perhaps we should execute a count() here so that we # can still use LIMIT/OFFSET ? elif (isinstance(start, int) and start < 0) \ - or (isinstance(stop, int) and stop < 0): + or (isinstance(stop, int) and stop < 0): return list(self)[item] res = self.slice(start, stop) @@ -1294,13 +2546,42 @@ class Query(object): else: return list(res) else: - return list(self[item:item+1])[0] + if item == -1: + return list(self)[-1] + else: + return list(self[item:item + 1])[0] @_generative(_no_statement_condition) def slice(self, start, stop): - """apply LIMIT/OFFSET to the ``Query`` based on a " - "range and return the newly resulting ``Query``.""" - + """Computes the "slice" of the :class:`.Query` represented by + the given indices and returns the resulting :class:`.Query`. + + The start and stop indices behave like the argument to Python's + built-in :func:`range` function. This method provides an + alternative to using ``LIMIT``/``OFFSET`` to get a slice of the + query. + + For example, :: + + session.query(User).order_by(User.id).slice(1, 3) + + renders as + + .. sourcecode:: sql + + SELECT users.id AS users_id, + users.name AS users_name + FROM users ORDER BY users.id + LIMIT ? OFFSET ? + (2, 1) + + .. seealso:: + + :meth:`.Query.limit` + + :meth:`.Query.offset` + + """ if start is not None and stop is not None: self._offset = (self._offset or 0) + start self._limit = stop - start @@ -1309,10 +2590,12 @@ class Query(object): elif start is not None and stop is None: self._offset = (self._offset or 0) + start + if self._offset == 0: + self._offset = None + @_generative(_no_statement_condition) def limit(self, limit): """Apply a ``LIMIT`` to the query and return the newly resulting - ``Query``. """ @@ -1327,12 +2610,89 @@ class Query(object): self._offset = offset @_generative(_no_statement_condition) - def distinct(self): - """Apply a ``DISTINCT`` to the query and return the newly resulting + def distinct(self, *criterion): + r"""Apply a ``DISTINCT`` to the query and return the newly resulting ``Query``. + + .. note:: + + The :meth:`.distinct` call includes logic that will automatically + add columns from the ORDER BY of the query to the columns + clause of the SELECT statement, to satisfy the common need + of the database backend that ORDER BY columns be part of the + SELECT list when DISTINCT is used. 
These columns *are not*
+            added to the list of columns actually fetched by the
+            :class:`.Query`, however, so would not affect results.
+            The columns are passed through when using the
+            :attr:`.Query.statement` accessor, however.
+
+        :param \*expr: optional column expressions.  When present,
+         the PostgreSQL dialect will render a ``DISTINCT ON (<expressions>)``
+         construct.
+
        """
-        self._distinct = True
+        if not criterion:
+            self._distinct = True
+        else:
+            criterion = self._adapt_col_list(criterion)
+            if isinstance(self._distinct, list):
+                self._distinct += criterion
+            else:
+                self._distinct = criterion
+
+    @_generative()
+    def prefix_with(self, *prefixes):
+        r"""Apply the prefixes to the query and return the newly resulting
+        ``Query``.
+
+        :param \*prefixes: optional prefixes, typically strings,
+         not using any commas.  In particular, this is useful for MySQL
+         keywords.
+
+        e.g.::
+
+            query = sess.query(User.name).\
+                prefix_with('HIGH_PRIORITY').\
+                prefix_with('SQL_SMALL_RESULT', 'ALL')
+
+        Would render::
+
+            SELECT HIGH_PRIORITY SQL_SMALL_RESULT ALL users.name AS users_name
+            FROM users
+
+        .. versionadded:: 0.7.7
+
+        .. seealso::
+
+            :meth:`.HasPrefixes.prefix_with`
+
+        """
+        if self._prefixes:
+            self._prefixes += prefixes
+        else:
+            self._prefixes = prefixes
+
+    @_generative()
+    def suffix_with(self, *suffixes):
+        r"""Apply the suffix to the query and return the newly resulting
+        ``Query``.
+
+        :param \*suffixes: optional suffixes, typically strings,
+         not using any commas.
+
+        .. versionadded:: 1.0.0
+
+        .. seealso::
+
+            :meth:`.Query.prefix_with`
+
+            :meth:`.HasSuffixes.suffix_with`
+
+        """
+        if self._suffixes:
+            self._suffixes += suffixes
+        else:
+            self._suffixes = suffixes

    def all(self):
        """Return the results represented by this ``Query`` as a list.
@@ -1349,31 +2709,44 @@ class Query(object):
        This method bypasses all internal statement compilation, and
        the statement is executed without modification.

-        The statement argument is either a string, a ``select()`` construct,
-        or a ``text()`` construct, and should return the set of columns
-        appropriate to the entity class represented by this ``Query``.
+        The statement is typically either a :func:`~.expression.text`
+        or :func:`~.expression.select` construct, and should return the set
+        of columns
+        appropriate to the entity class represented by this :class:`.Query`.

-        Also see the ``instances()`` method.
+        .. seealso::
+
+            :ref:`orm_tutorial_literal_sql` - usage examples in the
+            ORM tutorial

        """
-        if isinstance(statement, basestring):
-            statement = sql.text(statement)
+        statement = expression._expression_literal_as_text(statement)

-        if not isinstance(statement, (expression._TextClause, expression._SelectBaseMixin)):
-            raise sa_exc.ArgumentError("from_statement accepts text(), select(), and union() objects only.")
+        if not isinstance(statement,
+                          (expression.TextClause,
+                           expression.SelectBase)):
+            raise sa_exc.ArgumentError(
+                "from_statement accepts text(), select(), "
+                "and union() objects only.")

        self._statement = statement

    def first(self):
-        """Return the first result of this ``Query`` or
+        """Return the first result of this ``Query`` or
        None if the result doesn't contain any row.
-
+
        first() applies a limit of one within the generated SQL, so that
-        only one primary entity row is generated on the server side
-        (note this may consist of multiple result rows if join-loaded
+        only one primary entity row is generated on the server side
+        (note this may consist of multiple result rows if join-loaded
        collections are present).
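The slice arithmetic shown above maps Python ranges directly onto
LIMIT/OFFSET; a minimal sketch, reusing the hypothetical ``User`` mapping
from the earlier example::

    q = session.query(User).order_by(User.id).slice(1, 3)
    print(q)   # renders LIMIT/OFFSET bound to (2, 1): limit = 3 - 1, offset = 1

    # __getitem__ routes q[1:3] through slice(); a negative index such as
    # q[-1] instead falls back to fetching the full result list and
    # indexing it in Python, as the code above shows.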
- Calling ``first()`` results in an execution of the underlying query. + Calling :meth:`.Query.first` results in an execution of the underlying query. + + .. seealso:: + + :meth:`.Query.one` + + :meth:`.Query.one_or_none` """ if self._statement is not None: @@ -1385,36 +2758,67 @@ class Query(object): else: return None - def one(self): - """Return exactly one result or raise an exception. + def one_or_none(self): + """Return at most one result or raise an exception. - Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects - no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` + Returns ``None`` if the query selects + no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` if multiple object identities are returned, or if multiple - rows are returned for a query that does not return object - identities. - - Note that an entity query, that is, one which selects one or - more mapped classes as opposed to individual column attributes, - may ultimately represent many rows but only one row of - unique entity or entities - this is a successful result for one(). + rows are returned for a query that returns only scalar values + as opposed to full identity-mapped entities. - Calling ``one()`` results in an execution of the underlying query. - As of 0.6, ``one()`` fully fetches all results instead of applying - any kind of limit, so that the "unique"-ing of entities does not - conceal multiple object identities. + Calling :meth:`.Query.one_or_none` results in an execution of the + underlying query. + + .. versionadded:: 1.0.9 + + Added :meth:`.Query.one_or_none` + + .. seealso:: + + :meth:`.Query.first` + + :meth:`.Query.one` """ ret = list(self) - + l = len(ret) if l == 1: return ret[0] elif l == 0: - raise orm_exc.NoResultFound("No row was found for one()") + return None else: + raise orm_exc.MultipleResultsFound( + "Multiple rows were found for one_or_none()") + + def one(self): + """Return exactly one result or raise an exception. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` + if multiple object identities are returned, or if multiple + rows are returned for a query that returns only scalar values + as opposed to full identity-mapped entities. + + Calling :meth:`.one` results in an execution of the underlying query. + + .. 
seealso:: + + :meth:`.Query.first` + + :meth:`.Query.one_or_none` + + """ + try: + ret = self.one_or_none() + except orm_exc.MultipleResultsFound: raise orm_exc.MultipleResultsFound( "Multiple rows were found for one()") + else: + if ret is None: + raise orm_exc.NoResultFound("No row was found for one()") + return ret def scalar(self): """Return the first element of the first result or None @@ -1450,11 +2854,98 @@ class Query(object): self.session._autoflush() return self._execute_and_instances(context) + def __str__(self): + context = self._compile_context() + try: + bind = self._get_bind_args( + context, self.session.get_bind) if self.session else None + except sa_exc.UnboundExecutionError: + bind = None + return str(context.statement.compile(bind)) + + def _connection_from_session(self, **kw): + conn = self.session.connection(**kw) + if self._execution_options: + conn = conn.execution_options(**self._execution_options) + return conn + def _execute_and_instances(self, querycontext): - result = self.session.execute( - querycontext.statement, params=self._params, - mapper=self._mapper_zero_or_none()) - return self.instances(result, querycontext) + conn = self._get_bind_args( + querycontext, + self._connection_from_session, + close_with_result=True) + + result = conn.execute(querycontext.statement, self._params) + return loading.instances(querycontext.query, result, querycontext) + + def _get_bind_args(self, querycontext, fn, **kw): + return fn( + mapper=self._bind_mapper(), + clause=querycontext.statement, + **kw + ) + + @property + def column_descriptions(self): + """Return metadata about the columns which would be + returned by this :class:`.Query`. + + Format is a list of dictionaries:: + + user_alias = aliased(User, name='user2') + q = sess.query(User, User.id, user_alias) + + # this expression: + q.column_descriptions + + # would return: + [ + { + 'name':'User', + 'type':User, + 'aliased':False, + 'expr':User, + 'entity': User + }, + { + 'name':'id', + 'type':Integer(), + 'aliased':False, + 'expr':User.id, + 'entity': User + }, + { + 'name':'user2', + 'type':User, + 'aliased':True, + 'expr':user_alias, + 'entity': user_alias + } + ] + + """ + + return [ + { + 'name': ent._label_name, + 'type': ent.type, + 'aliased': getattr(insp_ent, 'is_aliased_class', False), + 'expr': ent.expr, + 'entity': + getattr(insp_ent, "entity", None) + if ent.entity_zero is not None + and not insp_ent.is_clause_element + else None + } + for ent, insp_ent in [ + ( + _ent, + (inspect(_ent.entity_zero) + if _ent.entity_zero is not None else None) + ) + for _ent in self._entities + ] + ] def instances(self, cursor, __context=None): """Given a ResultProxy cursor as returned by connection.execute(), @@ -1466,224 +2957,48 @@ class Query(object): for u in session.query(User).instances(result): print u """ - session = self.session - context = __context if context is None: context = QueryContext(self) - context.runid = _new_runid() - - filtered = bool(list(self._mapper_entities)) - single_entity = filtered and len(self._entities) == 1 - - if filtered: - if single_entity: - filter = lambda x: util.unique_list(x, util.IdentitySet) - else: - filter = util.unique_list - else: - filter = None - - custom_rows = single_entity and \ - 'append_result' in self._entities[0].extension - - (process, labels) = \ - zip(*[ - query_entity.row_processor(self, context, custom_rows) - for query_entity in self._entities - ]) - - if not single_entity: - labels = [l for l in labels if l] - - while True: - context.progress = {} - 
context.partials = {} - - if self._yield_per: - fetch = cursor.fetchmany(self._yield_per) - if not fetch: - break - else: - fetch = cursor.fetchall() - - if custom_rows: - rows = [] - for row in fetch: - process[0](row, rows) - elif single_entity: - rows = [process[0](row, None) for row in fetch] - else: - rows = [util.NamedTuple([proc(row, None) for proc in process], labels) - for row in fetch] - - if filter: - rows = filter(rows) - - if context.refresh_state and self._only_load_props \ - and context.refresh_state in context.progress: - context.refresh_state.commit( - context.refresh_state.dict, self._only_load_props) - context.progress.pop(context.refresh_state) - - session._finalize_loaded(context.progress) - - for ii, (dict_, attrs) in context.partials.iteritems(): - ii.commit(dict_, attrs) - - for row in rows: - yield row - - if not self._yield_per: - break + return loading.instances(self, cursor, context) def merge_result(self, iterator, load=True): - """Merge a result into this Query's Session. - - Given an iterator returned by a Query of the same structure as this one, - return an identical iterator of results, with all mapped instances - merged into the session using Session.merge(). This is an optimized - method which will merge all mapped instances, preserving the structure - of the result rows and unmapped columns with less method overhead than - that of calling Session.merge() explicitly for each value. - - The structure of the results is determined based on the column list - of this Query - if these do not correspond, unchecked errors will occur. - - The 'load' argument is the same as that of Session.merge(). - + """Merge a result into this :class:`.Query` object's Session. + + Given an iterator returned by a :class:`.Query` of the same structure + as this one, return an identical iterator of results, with all mapped + instances merged into the session using :meth:`.Session.merge`. This + is an optimized method which will merge all mapped instances, + preserving the structure of the result rows and unmapped columns with + less method overhead than that of calling :meth:`.Session.merge` + explicitly for each value. + + The structure of the results is determined based on the column list of + this :class:`.Query` - if these do not correspond, unchecked errors + will occur. + + The 'load' argument is the same as that of :meth:`.Session.merge`. + + For an example of how :meth:`~.Query.merge_result` is used, see + the source code for the example :ref:`examples_caching`, where + :meth:`~.Query.merge_result` is used to efficiently restore state + from a cache back into a target :class:`.Session`. 
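The ``column_descriptions`` accessor shown a little earlier is purely
introspective and also needs no database; a sketch, again assuming the
hypothetical ``User`` mapping::

    q = session.query(User, User.id)
    for desc in q.column_descriptions:
        print(desc['name'], desc['type'], desc['aliased'], desc['entity'])
    # one entry per selected item: the User entity itself, then the
    # User.id column, each tagged with its type, aliased flag and
    # originating entity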
+ """ - - session = self.session - if load: - # flush current contents if we expect to load data - session._autoflush() - - autoflush = session.autoflush - try: - session.autoflush = False - single_entity = len(self._entities) == 1 - if single_entity: - if isinstance(self._entities[0], _MapperEntity): - result = [session._merge( - attributes.instance_state(instance), - attributes.instance_dict(instance), - load=load, _recursive={}) - for instance in iterator] - else: - result = list(iterator) - else: - mapped_entities = [i for i, e in enumerate(self._entities) - if isinstance(e, _MapperEntity)] - result = [] - for row in iterator: - newrow = list(row) - for i in mapped_entities: - newrow[i] = session._merge( - attributes.instance_state(newrow[i]), - attributes.instance_dict(newrow[i]), - load=load, _recursive={}) - result.append(util.NamedTuple(newrow, row._labels)) - - return iter(result) - finally: - session.autoflush = autoflush - - - def _get(self, key=None, ident=None, refresh_state=None, lockmode=None, - only_load_props=None, passive=None): - lockmode = lockmode or self._lockmode - - mapper = self._mapper_zero() - if not self._populate_existing and \ - not refresh_state and \ - not mapper.always_refresh and \ - lockmode is None: - instance = self.session.identity_map.get(key) - if instance: - # item present in identity map with a different class - if not issubclass(instance.__class__, mapper.class_): - return None - - state = attributes.instance_state(instance) - - # expired - ensure it still exists - if state.expired: - if passive is attributes.PASSIVE_NO_FETCH: - return attributes.PASSIVE_NO_RESULT - try: - state() - except orm_exc.ObjectDeletedError: - self.session._remove_newly_deleted(state) - return None - return instance - elif passive is attributes.PASSIVE_NO_FETCH: - return attributes.PASSIVE_NO_RESULT - if ident is None: - if key is not None: - ident = key[1] - else: - ident = util.to_list(ident) - - if refresh_state is None: - q = self._clone() - q._get_condition() - else: - q = self._clone() - - if ident is not None: - (_get_clause, _get_params) = mapper._get_clause - - # None present in ident - turn those comparisons - # into "IS NULL" - if None in ident: - nones = set([ - _get_params[col].key for col, value in - zip(mapper.primary_key, ident) if value is None - ]) - _get_clause = sql_util.adapt_criterion_to_null( - _get_clause, nones) - - _get_clause = q._adapt_clause(_get_clause, True, False) - q._criterion = _get_clause - - params = dict([ - (_get_params[primary_key].key, id_val) - for id_val, primary_key in zip(ident, mapper.primary_key) - ]) - - if len(params) != len(mapper.primary_key): - raise sa_exc.InvalidRequestError( - "Incorrect number of values in identifier to formulate primary " - "key for query.get(); primary key columns are %s" % - ','.join("'%s'" % c for c in mapper.primary_key)) - - q._params = params - - if lockmode is not None: - q._lockmode = lockmode - q._get_options( - populate_existing=bool(refresh_state), - version_check=(lockmode is not None), - only_load_props=only_load_props, - refresh_state=refresh_state) - q._order_by = None - - try: - return q.one() - except orm_exc.NoResultFound: - return None + return loading.merge_result(self, iterator, load) @property def _select_args(self): return { - 'limit':self._limit, - 'offset':self._offset, - 'distinct':self._distinct, - 'group_by':self._group_by or None, - 'having':self._having + 'limit': self._limit, + 'offset': self._offset, + 'distinct': self._distinct, + 'prefixes': self._prefixes, + 
'suffixes': self._suffixes,
+            'group_by': self._group_by or None,
+            'having': self._having
        }

    @property
@@ -1693,84 +3008,107 @@ class Query(object):
            kwargs.get('offset') is not None or
            kwargs.get('distinct', False))

-    def count(self):
-        """Return a count of rows this Query would return.
-
-        For simple entity queries, count() issues
-        a SELECT COUNT, and will specifically count the primary
-        key column of the first entity only.  If the query uses
-        LIMIT, OFFSET, or DISTINCT, count() will wrap the statement
-        generated by this Query in a subquery, from which a SELECT COUNT
-        is issued, so that the contract of "how many rows
-        would be returned?" is honored.
-
-        For queries that request specific columns or expressions,
-        count() again makes no assumptions about those expressions
-        and will wrap everything in a subquery.  Therefore,
-        ``Query.count()`` is usually not what you want in this case.
-        To count specific columns, often in conjunction with
-        GROUP BY, use ``func.count()`` as an individual column expression
-        instead of ``Query.count()``.  See the ORM tutorial
-        for an example.
+    def exists(self):
+        """A convenience method that turns a query into an EXISTS subquery
+        of the form EXISTS (SELECT 1 FROM ... WHERE ...).
+
+        e.g.::
+
+            q = session.query(User).filter(User.name == 'fred')
+            session.query(q.exists())
+
+        Producing SQL similar to::
+
+            SELECT EXISTS (
+                SELECT 1 FROM users WHERE users.name = :name_1
+            ) AS anon_1
+
+        The EXISTS construct is usually used in the WHERE clause::
+
+            session.query(User.id).filter(q.exists()).scalar()
+
+        Note that some databases such as SQL Server don't allow an
+        EXISTS expression to be present in the columns clause of a
+        SELECT.  To select a simple boolean value based on the exists
+        as a WHERE, use :func:`.literal`::
+
+            from sqlalchemy import literal
+
+            session.query(literal(True)).filter(q.exists()).scalar()
+
+        .. versionadded:: 0.8.1

        """
-        should_nest = [self._should_nest_selectable]
-        def ent_cols(ent):
-            if isinstance(ent, _MapperEntity):
-                return ent.mapper.primary_key
-            else:
-                should_nest[0] = True
-                return [ent.column]
-        return self._col_aggregate(sql.literal_column('1'), sql.func.count,
-            nested_cols=chain(*[ent_cols(ent) for ent in self._entities]),
-            should_nest = should_nest[0]
-        )
+        # .add_columns() for the case that we are a query().select_from(X),
+        # so that ".statement" can be produced (#2995) but also without
+        # omitting the FROM clause from a query(X) (#2818);
+        # .with_only_columns() after we have a core select() so that
+        # we get just "SELECT 1" without any entities.
+        return sql.exists(self.add_columns('1').with_labels().
+                          statement.with_only_columns([1]))

-    def _col_aggregate(self, col, func, nested_cols=None, should_nest=False):
-        context = QueryContext(self)
+    def count(self):
+        r"""Return a count of rows this Query would return.

-        for entity in self._entities:
-            entity.setup_context(self, context)
+        This generates the SQL for this Query as follows::

-        if context.from_clause:
-            from_obj = list(context.from_clause)
-        else:
-            from_obj = context.froms
+            SELECT count(1) AS count_1 FROM (
                SELECT <rest of query follows...>
+            ) AS anon_1

-        self._adjust_for_single_inheritance(context)
+        .. versionchanged:: 0.7
+            The above scheme is newly refined as of 0.7b3.
- whereclause = context.whereclause + For fine grained control over specific columns + to count, to skip the usage of a subquery or + otherwise control of the FROM clause, + or to use other aggregate functions, + use :attr:`~sqlalchemy.sql.expression.func` + expressions in conjunction + with :meth:`~.Session.query`, i.e.:: - if should_nest: - if not nested_cols: - nested_cols = [col] - else: - nested_cols = list(nested_cols) - s = sql.select(nested_cols, whereclause, - from_obj=from_obj, use_labels=True, - **self._select_args) - s = s.alias() - s = sql.select( - [func(s.corresponding_column(col) or col)]).select_from(s) - else: - s = sql.select([func(col)], whereclause, from_obj=from_obj, - **self._select_args) + from sqlalchemy import func - if self._autoflush and not self._populate_existing: - self.session._autoflush() - return self.session.scalar(s, params=self._params, - mapper=self._mapper_zero()) + # count User records, without + # using a subquery. + session.query(func.count(User.id)) + + # return count of user "id" grouped + # by "name" + session.query(func.count(User.id)).\ + group_by(User.name) + + from sqlalchemy import distinct + + # count distinct "name" values + session.query(func.count(distinct(User.name))) + + """ + col = sql.func.count(sql.literal_column('*')) + return self.from_self(col).scalar() def delete(self, synchronize_session='evaluate'): - """Perform a bulk delete query. + r"""Perform a bulk delete query. Deletes rows matched by this query from the database. + E.g.:: + + sess.query(User).filter(User.age == 25).\ + delete(synchronize_session=False) + + sess.query(User).filter(User.age == 25).\ + delete(synchronize_session='evaluate') + + .. warning:: The :meth:`.Query.delete` method is a "bulk" operation, + which bypasses ORM unit-of-work automation in favor of greater + performance. **Please read all caveats and warnings below.** + :param synchronize_session: chooses the strategy for the removal of matched objects from the session. Valid values are: - - False - don't synchronize the session. This option is the most + + ``False`` - don't synchronize the session. This option is the most efficient and is reliable once the session is expired, which typically occurs after a commit(), or explicitly using expire_all(). Before the expiration, objects may still remain in @@ -1778,391 +3116,429 @@ class Query(object): results if they are accessed via get() or already loaded collections. - 'fetch' - performs a select query before the delete to find + ``'fetch'`` - performs a select query before the delete to find objects that are matched by the delete query and need to be removed from the session. Matched objects are removed from the session. - 'evaluate' - Evaluate the query's criteria in Python straight on - the objects in the session. If evaluation of the criteria isn't - implemented, an error is raised. In that case you probably - want to use the 'fetch' strategy as a fallback. - + ``'evaluate'`` - Evaluate the query's criteria in Python straight + on the objects in the session. If evaluation of the criteria isn't + implemented, an error is raised. + The expression evaluator currently doesn't account for differing string collations between the database and Python. - Returns the number of rows deleted, excluding any cascades. + :return: the count of rows matched as returned by the database's + "row count" feature. 
-        The method does *not* offer in-Python cascading of relationships - it is
-        assumed that ON DELETE CASCADE is configured for any foreign key
-        references which require it. The Session needs to be expired (occurs
-        automatically after commit(), or call expire_all()) in order for the
-        state of dependent objects subject to delete or delete-orphan cascade
-        to be correctly represented.
+        .. warning:: **Additional Caveats for bulk query deletes**

-        Also, the ``before_delete()`` and ``after_delete()``
-        :class:`~sqlalchemy.orm.interfaces.MapperExtension` methods are not
-        called from this method.  For a delete hook here, use the
-        ``after_bulk_delete()``
-        :class:`~sqlalchemy.orm.interfaces.MapperExtension` method.
+            * This method does **not work for joined
+              inheritance mappings**, since the **multiple table
+              deletes are not supported by SQL** as well as that the
+              **join condition of an inheritance mapper is not
+              automatically rendered**.  Care must be taken in any
+              multiple-table delete to first accommodate via some other means
+              how the related table will be deleted, as well as to
+              explicitly include the joining
+              condition between those tables, even in mappings where
+              this is normally automatic.  E.g. if a class ``Engineer``
+              subclasses ``Employee``, a DELETE against the ``Employee``
+              table would look like::
+
+                    session.query(Engineer).\
+                        filter(Engineer.id == Employee.id).\
+                        filter(Employee.name == 'dilbert').\
+                        delete()
+
+              However the above SQL will not delete from the Engineer table,
+              unless an ON DELETE CASCADE rule is established in the database
+              to handle it.
+
+              Short story, **do not use this method for joined inheritance
+              mappings unless you have taken the additional steps to make
+              this feasible**.
+
+            * The polymorphic identity WHERE criteria is **not** included
+              for single- or
+              joined- table updates - this must be added **manually** even
+              for single table inheritance.
+
+            * The method does **not** offer in-Python cascading of
+              relationships - it is assumed that ON DELETE CASCADE/SET
+              NULL/etc. is configured for any foreign key references
+              which require it, otherwise the database may emit an
+              integrity violation if foreign key references are being
+              enforced.
+
+              After the DELETE, dependent objects in the
+              :class:`.Session` which were impacted by an ON DELETE
+              may not contain the current state, or may have been
+              deleted. This issue is resolved once the
+              :class:`.Session` is expired, which normally occurs upon
+              :meth:`.Session.commit` or can be forced by using
+              :meth:`.Session.expire_all`.  Accessing an expired
+              object whose row has been deleted will invoke a SELECT
+              to locate the row; when the row is not found, an
+              :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is
+              raised.
+
+            * The ``'fetch'`` strategy results in an additional
+              SELECT statement emitted and will significantly reduce
+              performance.
+
+            * The ``'evaluate'`` strategy performs a scan of
+              all matching objects within the :class:`.Session`; if the
+              contents of the :class:`.Session` are expired, such as
+              via a preceding :meth:`.Session.commit` call, **this will
+              result in SELECT queries emitted for every matching object**.
+
+            * The :meth:`.MapperEvents.before_delete` and
+              :meth:`.MapperEvents.after_delete`
+              events **are not invoked** from this method.  Instead, the
+              :meth:`.SessionEvents.after_bulk_delete` method is provided to
+              act upon a mass DELETE of entity rows.
+
+        ..
seealso:: + + :meth:`.Query.update` + + :ref:`inserts_and_updates` - Core SQL tutorial """ - #TODO: lots of duplication and ifs - probably needs to be refactored to strategies - #TODO: cascades need handling. - if synchronize_session not in [False, 'evaluate', 'fetch']: - raise sa_exc.ArgumentError("Valid strategies for session " - "synchronization are False, 'evaluate' and 'fetch'") - self._no_select_modifiers("delete") + delete_op = persistence.BulkDelete.factory( + self, synchronize_session) + delete_op.exec_() + return delete_op.rowcount - self = self.enable_eagerloads(False) - - context = self._compile_context() - if len(context.statement.froms) != 1 or \ - not isinstance(context.statement.froms[0], schema.Table): - raise sa_exc.ArgumentError("Only deletion via a single table " - "query is currently supported") - primary_table = context.statement.froms[0] - - session = self.session - - if synchronize_session == 'evaluate': - try: - evaluator_compiler = evaluator.EvaluatorCompiler() - if self.whereclause is not None: - eval_condition = evaluator_compiler.process(self.whereclause) - else: - def eval_condition(obj): - return True - - except evaluator.UnevaluatableError: - raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python. " - "Specify 'fetch' or False for the synchronize_session parameter.") - - delete_stmt = sql.delete(primary_table, context.whereclause) - - if synchronize_session == 'fetch': - #TODO: use RETURNING when available - select_stmt = context.statement.with_only_columns(primary_table.primary_key) - matched_rows = session.execute(select_stmt, params=self._params).fetchall() - - if self._autoflush: - session._autoflush() - result = session.execute(delete_stmt, params=self._params) - - if synchronize_session == 'evaluate': - target_cls = self._mapper_zero().class_ - - #TODO: detect when the where clause is a trivial primary key match - objs_to_expunge = [obj for (cls, pk),obj in session.identity_map.iteritems() - if issubclass(cls, target_cls) and eval_condition(obj)] - for obj in objs_to_expunge: - session._remove_newly_deleted(attributes.instance_state(obj)) - elif synchronize_session == 'fetch': - target_mapper = self._mapper_zero() - for primary_key in matched_rows: - identity_key = target_mapper.identity_key_from_primary_key(list(primary_key)) - if identity_key in session.identity_map: - session._remove_newly_deleted(attributes.instance_state(session.identity_map[identity_key])) - - for ext in session.extensions: - ext.after_bulk_delete(session, self, context, result) - - return result.rowcount - - def update(self, values, synchronize_session='evaluate'): - """Perform a bulk update query. + def update(self, values, synchronize_session='evaluate', update_args=None): + r"""Perform a bulk update query. Updates rows matched by this query in the database. - :param values: a dictionary with attributes names as keys and literal - values or sql expressions as values. + E.g.:: + + sess.query(User).filter(User.age == 25).\ + update({User.age: User.age - 10}, synchronize_session=False) + + sess.query(User).filter(User.age == 25).\ + update({"age": User.age - 10}, synchronize_session='evaluate') + + + .. warning:: The :meth:`.Query.update` method is a "bulk" operation, + which bypasses ORM unit-of-work automation in favor of greater + performance. 
**Please read all caveats and warnings below.**
+
+        :param values: a dictionary with attributes names, or alternatively
+         mapped attributes or SQL expressions, as keys, and literal
+         values or sql expressions as values.  If :ref:`parameter-ordered
+         mode <updates_order_parameters>` is desired, the values can be
+         passed as a list of 2-tuples;
+         this requires that the :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
+         flag is passed to the :paramref:`.Query.update.update_args` dictionary
+         as well.
+
+          .. versionchanged:: 1.0.0 - string names in the values dictionary
+             are now resolved against the mapped entity; previously, these
+             strings were passed as literal column names with no mapper-level
+             translation.

        :param synchronize_session: chooses the strategy to update the
-         attributes on objects in the session. Valid values are:
+         attributes on objects in the session.  Valid values are:

-            False - don't synchronize the session. This option is the most
+            ``False`` - don't synchronize the session. This option is the most
             efficient and is reliable once the session is expired, which
             typically occurs after a commit(), or explicitly using
             expire_all(). Before the expiration, updated objects may still
             remain in the session with stale values on their attributes, which
             can lead to confusing results.
-
-            'fetch' - performs a select query before the update to find
+
+            ``'fetch'`` - performs a select query before the update to find
             objects that are matched by the update query. The updated
             attributes are expired on matched objects.

-            'evaluate' - Evaluate the Query's criteria in Python straight on
-            the objects in the session. If evaluation of the criteria isn't
+            ``'evaluate'`` - Evaluate the Query's criteria in Python straight
+            on the objects in the session. If evaluation of the criteria isn't
             implemented, an exception is raised.

             The expression evaluator currently doesn't account for differing
             string collations between the database and Python.

-        Returns the number of rows matched by the update.
+        :param update_args: Optional dictionary, if present will be passed
+         to the underlying :func:`.update` construct as the ``**kw`` for
+         the object.  May be used to pass dialect-specific arguments such
+         as ``mysql_limit``, as well as other special arguments such as
+         :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`.

-        The method does *not* offer in-Python cascading of relationships - it is assumed that
-        ON UPDATE CASCADE is configured for any foreign key references which require it.
+          .. versionadded:: 1.0.0

-        The Session needs to be expired (occurs automatically after commit(), or call expire_all())
-        in order for the state of dependent objects subject foreign key cascade to be
-        correctly represented.
+        :return: the count of rows matched as returned by the database's
+          "row count" feature.

-        Also, the ``before_update()`` and ``after_update()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension`
-        methods are not called from this method. For an update hook here, use the
-        ``after_bulk_update()`` :class:`~sqlalchemy.orm.interfaces.SessionExtension` method.
+        .. warning:: **Additional Caveats for bulk query updates**
+
+            * The method does **not** offer in-Python cascading of
+              relationships - it is assumed that ON UPDATE CASCADE is
+              configured for any foreign key references which require
+              it, otherwise the database may emit an integrity
+              violation if foreign key references are being enforced.
+
+              After the UPDATE, dependent objects in the
+              :class:`.Session` which were impacted by an ON UPDATE
+              CASCADE may not contain the current state; this issue is
+              resolved once the :class:`.Session` is expired, which
+              normally occurs upon :meth:`.Session.commit` or can be
+              forced by using :meth:`.Session.expire_all`.
+
+            * The ``'fetch'`` strategy results in an additional
+              SELECT statement emitted and will significantly reduce
+              performance.
+
+            * The ``'evaluate'`` strategy performs a scan of
+              all matching objects within the :class:`.Session`; if the
+              contents of the :class:`.Session` are expired, such as
+              via a preceding :meth:`.Session.commit` call, **this will
+              result in SELECT queries emitted for every matching object**.
+
+            * The method supports multiple table updates, as detailed
+              in :ref:`multi_table_updates`, and this behavior does
+              extend to support updates of joined-inheritance and
+              other multiple table mappings.  However, the **join
+              condition of an inheritance mapper is not
+              automatically rendered**.  Care must be taken in any
+              multiple-table update to explicitly include the joining
+              condition between those tables, even in mappings where
+              this is normally automatic. E.g. if a class ``Engineer``
+              subclasses ``Employee``, an UPDATE of the ``Engineer``
+              local table using criteria against the ``Employee``
+              local table might look like::
+
+                    session.query(Engineer).\
+                        filter(Engineer.id == Employee.id).\
+                        filter(Employee.name == 'dilbert').\
+                        update({"engineer_type": "programmer"})
+
+            * The polymorphic identity WHERE criteria is **not** included
+              for single- or
+              joined- table updates - this must be added **manually**, even
+              for single table inheritance.
+
+            * The :meth:`.MapperEvents.before_update` and
+              :meth:`.MapperEvents.after_update`
+              events **are not invoked from this method**.  Instead, the
+              :meth:`.SessionEvents.after_bulk_update` method is provided to
+              act upon a mass UPDATE of entity rows.
+
+        .. seealso::
+
+            :meth:`.Query.delete`
+
+            :ref:`inserts_and_updates` - Core SQL tutorial

        """
-        #TODO: value keys need to be mapped to corresponding sql cols and instr.attr.s to string keys
-        #TODO: updates of manytoone relationships need to be converted to fk assignments
-        #TODO: cascades need handling.
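A sketch of the bulk UPDATE forms documented above, reusing the
hypothetical ``User`` mapping from the earlier example; the call executes
immediately, so a :class:`.Session` bound to an engine is assumed::

    def rename_ed(session):
        # dictionary form with a mapped attribute as the key; the
        # update_args entry passes MySQL's mysql_limit through to the
        # underlying update() construct, per the docstring above
        return session.query(User).filter(User.name == 'ed').update(
            {User.name: 'edward'},
            synchronize_session=False,
            update_args={'mysql_limit': 1})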
- - if synchronize_session == 'expire': - util.warn_deprecated("The 'expire' value as applied to " - "the synchronize_session argument of " - "query.update() is now called 'fetch'") - synchronize_session = 'fetch' - - if synchronize_session not in [False, 'evaluate', 'fetch']: - raise sa_exc.ArgumentError("Valid strategies for session synchronization are False, 'evaluate' and 'fetch'") - self._no_select_modifiers("update") - - self = self.enable_eagerloads(False) - - context = self._compile_context() - if len(context.statement.froms) != 1 or not isinstance(context.statement.froms[0], schema.Table): - raise sa_exc.ArgumentError("Only update via a single table query is currently supported") - primary_table = context.statement.froms[0] - - session = self.session - - if synchronize_session == 'evaluate': - try: - evaluator_compiler = evaluator.EvaluatorCompiler() - if self.whereclause is not None: - eval_condition = evaluator_compiler.process(self.whereclause) - else: - def eval_condition(obj): - return True - - value_evaluators = {} - for key,value in values.iteritems(): - key = expression._column_as_key(key) - value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value)) - except evaluator.UnevaluatableError: - raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python. " - "Specify 'fetch' or False for the synchronize_session parameter.") - - update_stmt = sql.update(primary_table, context.whereclause, values) - - if synchronize_session == 'fetch': - select_stmt = context.statement.with_only_columns(primary_table.primary_key) - matched_rows = session.execute(select_stmt, params=self._params).fetchall() - - if self._autoflush: - session._autoflush() - result = session.execute(update_stmt, params=self._params) - - if synchronize_session == 'evaluate': - target_cls = self._mapper_zero().class_ - - for (cls, pk),obj in session.identity_map.iteritems(): - evaluated_keys = value_evaluators.keys() - - if issubclass(cls, target_cls) and eval_condition(obj): - state, dict_ = attributes.instance_state(obj), attributes.instance_dict(obj) - - # only evaluate unmodified attributes - to_evaluate = state.unmodified.intersection(evaluated_keys) - for key in to_evaluate: - dict_[key] = value_evaluators[key](obj) - - state.commit(dict_, list(to_evaluate)) - - # expire attributes with pending changes - # (there was no autoflush, so they are overwritten) - state.expire_attributes(dict_, set(evaluated_keys).difference(to_evaluate)) - - elif synchronize_session == 'fetch': - target_mapper = self._mapper_zero() - - for primary_key in matched_rows: - identity_key = target_mapper.identity_key_from_primary_key(list(primary_key)) - if identity_key in session.identity_map: - session.expire( - session.identity_map[identity_key], - [expression._column_as_key(k) for k in values] - ) - - for ext in session.extensions: - ext.after_bulk_update(session, self, context, result) - - return result.rowcount + update_args = update_args or {} + update_op = persistence.BulkUpdate.factory( + self, synchronize_session, values, update_args) + update_op.exec_() + return update_op.rowcount def _compile_context(self, labels=True): + if self.dispatch.before_compile: + for fn in self.dispatch.before_compile: + new_query = fn(self) + if new_query is not None: + self = new_query + context = QueryContext(self) if context.statement is not None: return context - if self._lockmode: - try: - for_update = {'read': 'read', - 'update': True, - 'update_nowait': 'nowait', - None: False}[self._lockmode] - 
except KeyError: - raise sa_exc.ArgumentError("Unknown lockmode %r" % self._lockmode) - else: - for_update = False + context.labels = labels + + context._for_update_arg = self._for_update_arg for entity in self._entities: entity.setup_context(self, context) - + for rec in context.create_eager_joins: strategy = rec[0] strategy(*rec[1:]) - - eager_joins = context.eager_joins.values() if context.from_clause: - froms = list(context.from_clause) # "load from explicit FROMs" mode, - # i.e. when select_from() or join() is used - else: - froms = context.froms # "load from discrete FROMs" mode, - # i.e. when each _MappedEntity has its own FROM + # "load from explicit FROMs" mode, + # i.e. when select_from() or join() is used + context.froms = list(context.from_clause) + # else "load from discrete FROMs" mode, + # i.e. when each _MappedEntity has its own FROM - self._adjust_for_single_inheritance(context) + if self._enable_single_crit: + self._adjust_for_single_inheritance(context) if not context.primary_columns: if self._only_load_props: raise sa_exc.InvalidRequestError( - "No column-based properties specified for refresh operation." - " Use session.expire() to reload collections and related items.") + "No column-based properties specified for " + "refresh operation. Use session.expire() " + "to reload collections and related items.") else: raise sa_exc.InvalidRequestError( - "Query contains no columns with which to SELECT from.") + "Query contains no columns with which to " + "SELECT from.") if context.multi_row_eager_loaders and self._should_nest_selectable: - # for eager joins present and LIMIT/OFFSET/DISTINCT, - # wrap the query inside a select, - # then append eager joins onto that - - if context.order_by: - order_by_col_expr = list( - chain(*[ - sql_util.find_columns(o) - for o in context.order_by - ]) - ) - else: - context.order_by = None - order_by_col_expr = [] - - inner = sql.select( - context.primary_columns + order_by_col_expr, - context.whereclause, - from_obj=froms, - use_labels=labels, - correlate=False, - order_by=context.order_by, - **self._select_args - ) - - for hint in self._with_hints: - inner = inner.with_hint(*hint) - - if self._correlate: - inner = inner.correlate(*self._correlate) - - inner = inner.alias() - - equivs = self.__all_equivs() - - context.adapter = sql_util.ColumnAdapter(inner, equivs) - - statement = sql.select( - [inner] + context.secondary_columns, - for_update=for_update, - use_labels=labels) - - if self._execution_options: - statement = statement.execution_options(**self._execution_options) - - from_clause = inner - for eager_join in eager_joins: - # EagerLoader places a 'stop_on' attribute on the join, - # giving us a marker as to where the "splice point" of the join should be - from_clause = sql_util.splice_joins(from_clause, eager_join, eager_join.stop_on) - - statement.append_from(from_clause) - - if context.order_by: - statement.append_order_by(*context.adapter.copy_and_process(context.order_by)) - - statement.append_order_by(*context.eager_order_by) + context.statement = self._compound_eager_statement(context) else: - if not context.order_by: - context.order_by = None - - if self._distinct and context.order_by: - order_by_col_expr = list( - chain(*[ - sql_util.find_columns(o) - for o in context.order_by - ]) - ) - context.primary_columns += order_by_col_expr - - froms += tuple(context.eager_joins.values()) - - statement = sql.select( - context.primary_columns + context.secondary_columns, - context.whereclause, - from_obj=froms, - use_labels=labels, - 
for_update=for_update, - correlate=False, - order_by=context.order_by, - **self._select_args - ) - - for hint in self._with_hints: - statement = statement.with_hint(*hint) - - if self._execution_options: - statement = statement.execution_options(**self._execution_options) - - if self._correlate: - statement = statement.correlate(*self._correlate) - - if context.eager_order_by: - statement.append_order_by(*context.eager_order_by) - - context.statement = statement + context.statement = self._simple_statement(context) return context + def _compound_eager_statement(self, context): + # for eager joins present and LIMIT/OFFSET/DISTINCT, + # wrap the query inside a select, + # then append eager joins onto that + + if context.order_by: + order_by_col_expr = \ + sql_util.expand_column_list_from_order_by( + context.primary_columns, + context.order_by + ) + else: + context.order_by = None + order_by_col_expr = [] + + inner = sql.select( + context.primary_columns + order_by_col_expr, + context.whereclause, + from_obj=context.froms, + use_labels=context.labels, + # TODO: this order_by is only needed if + # LIMIT/OFFSET is present in self._select_args, + # else the application on the outside is enough + order_by=context.order_by, + **self._select_args + ) + + for hint in self._with_hints: + inner = inner.with_hint(*hint) + + if self._correlate: + inner = inner.correlate(*self._correlate) + + inner = inner.alias() + + equivs = self.__all_equivs() + + context.adapter = sql_util.ColumnAdapter(inner, equivs) + + statement = sql.select( + [inner] + context.secondary_columns, + use_labels=context.labels) + + statement._for_update_arg = context._for_update_arg + + from_clause = inner + for eager_join in context.eager_joins.values(): + # EagerLoader places a 'stop_on' attribute on the join, + # giving us a marker as to where the "splice point" of + # the join should be + from_clause = sql_util.splice_joins( + from_clause, + eager_join, eager_join.stop_on) + + statement.append_from(from_clause) + + if context.order_by: + statement.append_order_by( + *context.adapter.copy_and_process( + context.order_by + ) + ) + + statement.append_order_by(*context.eager_order_by) + return statement + + def _simple_statement(self, context): + if not context.order_by: + context.order_by = None + + if self._distinct is True and context.order_by: + context.primary_columns += \ + sql_util.expand_column_list_from_order_by( + context.primary_columns, + context.order_by + ) + context.froms += tuple(context.eager_joins.values()) + + statement = sql.select( + context.primary_columns + + context.secondary_columns, + context.whereclause, + from_obj=context.froms, + use_labels=context.labels, + order_by=context.order_by, + **self._select_args + ) + statement._for_update_arg = context._for_update_arg + + for hint in self._with_hints: + statement = statement.with_hint(*hint) + + if self._correlate: + statement = statement.correlate(*self._correlate) + + if context.eager_order_by: + statement.append_order_by(*context.eager_order_by) + return statement + def _adjust_for_single_inheritance(self, context): """Apply single-table-inheritance filtering. - For all distinct single-table-inheritance mappers represented in the - columns clause of this query, add criterion to the WHERE clause of the - given QueryContext such that only the appropriate subtypes are - selected from the total results. 
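# A minimal sketch of the behavior _adjust_for_single_inheritance
# implements: selecting a subclass in a single-table-inheritance
# hierarchy has the discriminator criterion added to the WHERE clause
# automatically.  The Employee/Manager model here is an assumption for
# illustration only.

from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

Base = declarative_base()

class Employee(Base):
    __tablename__ = 'employee'
    id = Column(Integer, primary_key=True)
    type = Column(String(20))
    __mapper_args__ = {'polymorphic_on': type,
                       'polymorphic_identity': 'employee'}

class Manager(Employee):
    __mapper_args__ = {'polymorphic_identity': 'manager'}

# The rendered statement includes a criterion along the lines of
# "employee.type IN (:type_1)" that the caller never spelled out.
print(Session().query(Manager))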
+ For all distinct single-table-inheritance mappers represented in + the columns clause of this query, add criterion to the WHERE + clause of the given QueryContext such that only the appropriate + subtypes are selected from the total results. """ - for entity, (mapper, adapter, s, i, w) in self._mapper_adapter_map.iteritems(): - single_crit = mapper._single_table_criterion + for (ext_info, adapter) in set(self._mapper_adapter_map.values()): + if ext_info in self._join_entities: + continue + single_crit = ext_info.mapper._single_table_criterion if single_crit is not None: if adapter: single_crit = adapter.traverse(single_crit) single_crit = self._adapt_clause(single_crit, False, False) - context.whereclause = sql.and_(context.whereclause, single_crit) + context.whereclause = sql.and_( + sql.True_._ifnone(context.whereclause), + single_crit) - def __str__(self): - return str(self._compile_context().statement) + +from ..sql.selectable import ForUpdateArg + + +class LockmodeArg(ForUpdateArg): + @classmethod + def parse_legacy_query(self, mode): + if mode in (None, False): + return None + + if mode == "read": + read = True + nowait = False + elif mode == "update": + read = nowait = False + elif mode == "update_nowait": + nowait = True + read = False + else: + raise sa_exc.ArgumentError( + "Unknown with_lockmode argument: %r" % mode) + + return LockmodeArg(read=read, nowait=nowait) class _QueryEntity(object): @@ -2171,8 +3547,11 @@ class _QueryEntity(object): def __new__(cls, *args, **kwargs): if cls is _QueryEntity: entity = args[1] - if not isinstance(entity, basestring) and _is_mapped_class(entity): + if not isinstance(entity, util.string_types) and \ + _is_mapped_class(entity): cls = _MapperEntity + elif isinstance(entity, Bundle): + cls = _BundleEntity else: cls = _ColumnEntity return object.__new__(cls) @@ -2182,63 +3561,102 @@ class _QueryEntity(object): q.__dict__ = self.__dict__.copy() return q + class _MapperEntity(_QueryEntity): """mapper/class/AliasedClass entity""" def __init__(self, query, entity): - self.primary_entity = not query._entities + if not query._primary_entity: + query._primary_entity = self query._entities.append(self) - + query._has_mapper_entities = True self.entities = [entity] - self.entity_zero = entity + self.expr = entity - def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic): - self.mapper = mapper - self.extension = self.mapper.extension - self.adapter = adapter - self.selectable = from_obj - self._with_polymorphic = with_polymorphic - self._polymorphic_discriminator = None - self.is_aliased_class = is_aliased_class - if is_aliased_class: - self.path_entity = self.entity = self.entity_zero = entity + supports_single_entity = True + + use_id_for_hash = True + + def setup_entity(self, ext_info, aliased_adapter): + self.mapper = ext_info.mapper + self.aliased_adapter = aliased_adapter + self.selectable = ext_info.selectable + self.is_aliased_class = ext_info.is_aliased_class + self._with_polymorphic = ext_info.with_polymorphic_mappers + self._polymorphic_discriminator = \ + ext_info.polymorphic_on + self.entity_zero = ext_info + if ext_info.is_aliased_class: + self._label_name = self.entity_zero.name else: - self.path_entity = mapper - self.entity = self.entity_zero = mapper + self._label_name = self.mapper.class_.__name__ + self.path = self.entity_zero._path_registry + + def set_with_polymorphic(self, query, cls_or_mappers, + selectable, polymorphic_on): + """Receive an update from a call to query.with_polymorphic(). 
+
+        Note that the newer style of using a free-standing
+        with_polymorphic() construct doesn't make use of this method.
+
+
+        """
+        if self.is_aliased_class:
+            # TODO: invalidrequest ?
+            raise NotImplementedError(
+                "Can't use with_polymorphic() against "
+                "an Aliased object"
+            )
-    def set_with_polymorphic(self, query, cls_or_mappers, selectable, discriminator):
        if cls_or_mappers is None:
            query._reset_polymorphic_adapter(self.mapper)
            return

-        mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable)
+        mappers, from_obj = self.mapper._with_polymorphic_args(
+            cls_or_mappers, selectable)
        self._with_polymorphic = mappers
-        self._polymorphic_discriminator = discriminator
+        self._polymorphic_discriminator = polymorphic_on

-        # TODO: do the wrapped thing here too so that with_polymorphic() can be
-        # applied to aliases
-        if not self.is_aliased_class:
-            self.selectable = from_obj
-            self.adapter = query._get_polymorphic_adapter(self, from_obj)
+        self.selectable = from_obj
+        query._mapper_loads_polymorphically_with(
+            self.mapper, sql_util.ColumnAdapter(
+                from_obj, self.mapper._equivalent_columns))
+
+    @property
+    def type(self):
+        return self.mapper.class_
+
+    @property
+    def entity_zero_or_selectable(self):
+        return self.entity_zero

    def corresponds_to(self, entity):
-        if _is_aliased_class(entity) or self.is_aliased_class:
-            return entity is self.path_entity
-        else:
-            return entity.common_parent(self.path_entity)
+        if entity.is_aliased_class:
+            if self.is_aliased_class:
+                if entity._base_alias is self.entity_zero._base_alias:
+                    return True
+            return False
+        elif self.is_aliased_class:
+            if self.entity_zero._use_mapper_path:
+                return entity in self._with_polymorphic
+            else:
+                return entity is self.entity_zero
+
+        return entity.common_parent(self.entity_zero)

    def adapt_to_selectable(self, query, sel):
        query._entities.append(self)

    def _get_entity_clauses(self, query, context):
-
-        adapter = None
-        if not self.is_aliased_class and query._polymorphic_adapters:
-            adapter = query._polymorphic_adapters.get(self.mapper, None)
-        if not adapter and self.adapter:
-            adapter = self.adapter
-
+        adapter = None
+
+        if not self.is_aliased_class:
+            if query._polymorphic_adapters:
+                adapter = query._polymorphic_adapters.get(self.mapper, None)
+        else:
+            adapter = self.aliased_adapter
+
        if adapter:
            if query._from_obj_alias:
                ret = adapter.wrap(query._from_obj_alias)
@@ -2249,7 +3667,7 @@ class _MapperEntity(_QueryEntity):
        return ret

-    def row_processor(self, query, context, custom_rows):
+    def row_processor(self, query, context, result):
        adapter = self._get_entity_clauses(query, context)

        if context.adapter and adapter:
@@ -2257,30 +3675,37 @@
        elif not adapter:
            adapter = context.adapter

-        # polymorphic mappers which have concrete tables in their hierarchy usually
+        # polymorphic mappers which have concrete tables in
+        # their hierarchy usually
        # require row aliasing unconditionally.
if not adapter and self.mapper._requires_row_aliasing: - adapter = sql_util.ColumnAdapter(self.selectable, self.mapper._equivalent_columns) + adapter = sql_util.ColumnAdapter( + self.selectable, + self.mapper._equivalent_columns) - if self.primary_entity: - _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter, - extension=self.extension, only_load_props=query._only_load_props, refresh_state=context.refresh_state, - polymorphic_discriminator=self._polymorphic_discriminator - ) + if query._primary_entity is self: + only_load_props = query._only_load_props + refresh_state = context.refresh_state else: - _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter, - polymorphic_discriminator=self._polymorphic_discriminator) + only_load_props = refresh_state = None - if self.is_aliased_class: - entname = self.entity._sa_label_name - else: - entname = self.mapper.class_.__name__ - - return _instance, entname + _instance = loading._instance_processor( + self.mapper, + context, + result, + self.path, + adapter, + only_load_props=only_load_props, + refresh_state=refresh_state, + polymorphic_discriminator=self._polymorphic_discriminator + ) + + return _instance, self._label_name def setup_context(self, query, context): adapter = self._get_entity_clauses(query, context) + # if self._adapted_selectable is None: context.froms += (self.selectable,) if context.order_by is False and self.mapper.order_by: @@ -2288,64 +3713,278 @@ class _MapperEntity(_QueryEntity): # apply adaptation to the mapper's order_by if needed. if adapter: - context.order_by = adapter.adapt_list(util.to_list(context.order_by)) + context.order_by = adapter.adapt_list( + util.to_list( + context.order_by + ) + ) - for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic): - if query._only_load_props and value.key not in query._only_load_props: - continue - value.setup( - context, - self, - (self.path_entity,), - adapter, - only_load_props=query._only_load_props, - column_collection=context.primary_columns - ) - - if self._polymorphic_discriminator is not None: - if adapter: - pd = adapter.columns[self._polymorphic_discriminator] - else: - pd = self._polymorphic_discriminator - context.primary_columns.append(pd) + loading._setup_entity_query( + context, self.mapper, self, + self.path, adapter, context.primary_columns, + with_polymorphic=self._with_polymorphic, + only_load_props=query._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator) def __str__(self): return str(self.mapper) + +@inspection._self_inspects +class Bundle(InspectionAttr): + """A grouping of SQL expressions that are returned by a :class:`.Query` + under one namespace. + + The :class:`.Bundle` essentially allows nesting of the tuple-based + results returned by a column-oriented :class:`.Query` object. It also + is extensible via simple subclassing, where the primary capability + to override is that of how the set of expressions should be returned, + allowing post-processing as well as custom return types, without + involving ORM identity-mapped classes. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`bundles` + + """ + + single_entity = False + """If True, queries for a single Bundle will be returned as a single + entity, rather than an element within a keyed tuple.""" + + is_clause_element = False + + is_mapper = False + + is_aliased_class = False + + def __init__(self, name, *exprs, **kw): + r"""Construct a new :class:`.Bundle`. 
+ + e.g.:: + + bn = Bundle("mybundle", MyClass.x, MyClass.y) + + for row in session.query(bn).filter( + bn.c.x == 5).filter(bn.c.y == 4): + print(row.mybundle.x, row.mybundle.y) + + :param name: name of the bundle. + :param \*exprs: columns or SQL expressions comprising the bundle. + :param single_entity=False: if True, rows for this :class:`.Bundle` + can be returned as a "single entity" outside of any enclosing tuple + in the same manner as a mapped entity. + + """ + self.name = self._label = name + self.exprs = exprs + self.c = self.columns = ColumnCollection() + self.columns.update((getattr(col, "key", col._label), col) + for col in exprs) + self.single_entity = kw.pop('single_entity', self.single_entity) + + columns = None + """A namespace of SQL expressions referred to by this :class:`.Bundle`. + + e.g.:: + + bn = Bundle("mybundle", MyClass.x, MyClass.y) + + q = sess.query(bn).filter(bn.c.x == 5) + + Nesting of bundles is also supported:: + + b1 = Bundle("b1", + Bundle('b2', MyClass.a, MyClass.b), + Bundle('b3', MyClass.x, MyClass.y) + ) + + q = sess.query(b1).filter( + b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9) + + .. seealso:: + + :attr:`.Bundle.c` + + """ + + c = None + """An alias for :attr:`.Bundle.columns`.""" + + def _clone(self): + cloned = self.__class__.__new__(self.__class__) + cloned.__dict__.update(self.__dict__) + return cloned + + def __clause_element__(self): + return expression.ClauseList(group=False, *self.c) + + @property + def clauses(self): + return self.__clause_element__().clauses + + def label(self, name): + """Provide a copy of this :class:`.Bundle` passing a new label.""" + + cloned = self._clone() + cloned.name = name + return cloned + + def create_row_processor(self, query, procs, labels): + """Produce the "row processing" function for this :class:`.Bundle`. + + May be overridden by subclasses. + + .. seealso:: + + :ref:`bundles` - includes an example of subclassing. 
+ + """ + keyed_tuple = util.lightweight_named_tuple('result', labels) + + def proc(row): + return keyed_tuple([proc(row) for proc in procs]) + return proc + + +class _BundleEntity(_QueryEntity): + use_id_for_hash = False + + def __init__(self, query, bundle, setup_entities=True): + query._entities.append(self) + self.bundle = self.expr = bundle + self.type = type(bundle) + self._label_name = bundle.name + self._entities = [] + + if setup_entities: + for expr in bundle.exprs: + if isinstance(expr, Bundle): + _BundleEntity(self, expr) + else: + _ColumnEntity(self, expr, namespace=self) + + self.supports_single_entity = self.bundle.single_entity + + @property + def entities(self): + entities = [] + for ent in self._entities: + entities.extend(ent.entities) + return entities + + @property + def entity_zero(self): + for ent in self._entities: + ezero = ent.entity_zero + if ezero is not None: + return ezero + else: + return None + + def corresponds_to(self, entity): + # TODO: this seems to have no effect for + # _ColumnEntity either + return False + + @property + def entity_zero_or_selectable(self): + for ent in self._entities: + ezero = ent.entity_zero_or_selectable + if ezero is not None: + return ezero + else: + return None + + def adapt_to_selectable(self, query, sel): + c = _BundleEntity(query, self.bundle, setup_entities=False) + # c._label_name = self._label_name + # c.entity_zero = self.entity_zero + # c.entities = self.entities + + for ent in self._entities: + ent.adapt_to_selectable(c, sel) + + def setup_entity(self, ext_info, aliased_adapter): + for ent in self._entities: + ent.setup_entity(ext_info, aliased_adapter) + + def setup_context(self, query, context): + for ent in self._entities: + ent.setup_context(query, context) + + def row_processor(self, query, context, result): + procs, labels = zip( + *[ent.row_processor(query, context, result) + for ent in self._entities] + ) + + proc = self.bundle.create_row_processor(query, procs, labels) + + return proc, self._label_name + + class _ColumnEntity(_QueryEntity): """Column/expression based entity.""" - def __init__(self, query, column): - if isinstance(column, basestring): + def __init__(self, query, column, namespace=None): + self.expr = column + self.namespace = namespace + search_entities = True + check_column = False + + if isinstance(column, util.string_types): column = sql.literal_column(column) - self._result_label = column.name - elif isinstance(column, attributes.QueryableAttribute): - self._result_label = column.key - column = column.__clause_element__() - else: - self._result_label = getattr(column, 'key', None) - - if not isinstance(column, expression.ColumnElement) and hasattr(column, '_select_iterable'): - for c in column._select_iterable: - if c is column: - break - _ColumnEntity(query, c) - - if c is not column: + self._label_name = column.name + search_entities = False + check_column = True + _entity = None + elif isinstance(column, ( + attributes.QueryableAttribute, + interfaces.PropComparator + )): + _entity = getattr(column, '_parententity', None) + if _entity is not None: + search_entities = False + self._label_name = column.key + column = column._query_clause_element() + check_column = True + if isinstance(column, Bundle): + _BundleEntity(query, column) return if not isinstance(column, sql.ColumnElement): - raise sa_exc.InvalidRequestError( - "SQL expression, column, or mapped entity expected - got '%r'" % column - ) + if hasattr(column, '_select_iterable'): + # break out an object like Table into + # individual 
columns + for c in column._select_iterable: + if c is column: + break + _ColumnEntity(query, c, namespace=column) + else: + return - # if the Column is unnamed, give it a + raise sa_exc.InvalidRequestError( + "SQL expression, column, or mapped entity " + "expected - got '%r'" % (column, ) + ) + elif not check_column: + self._label_name = getattr(column, 'key', None) + search_entities = True + + self.type = type_ = column.type + self.use_id_for_hash = not type_.hashable + + # If the Column is unnamed, give it a # label() so that mutable column expressions # can be located in the result even # if the expression's identity has been changed - # due to adaption - if not column._label: - column = column.label(None) + # due to adaption. + + if not column._label and not getattr(column, 'is_literal', False): + column = column.label(self._label_name) query._entities.append(self) @@ -2354,73 +3993,125 @@ class _ColumnEntity(_QueryEntity): # look for ORM entities represented within the # given expression. Try to count only entities - # for columns whos FROM object is in the actual list + # for columns whose FROM object is in the actual list # of FROMs for the overall expression - this helps # subqueries which were built from ORM constructs from # leaking out their entities into the main select construct - actual_froms = set(column._from_objects) + self.actual_froms = actual_froms = set(column._from_objects) - self.entities = util.OrderedSet( - elem._annotations['parententity'] - for elem in visitors.iterate(column, {}) - if 'parententity' in elem._annotations - and actual_froms.intersection(elem._from_objects) - ) - - if self.entities: - self.entity_zero = list(self.entities)[0] + if not search_entities: + self.entity_zero = _entity + if _entity: + self.entities = [_entity] + self.mapper = _entity.mapper + else: + self.entities = [] + self.mapper = None + self._from_entities = set(self.entities) else: - self.entity_zero = None - + all_elements = [ + elem for elem in sql_util.surface_column_elements(column) + if 'parententity' in elem._annotations + ] + + self.entities = util.unique_list([ + elem._annotations['parententity'] + for elem in all_elements + if 'parententity' in elem._annotations + ]) + + self._from_entities = set([ + elem._annotations['parententity'] + for elem in all_elements + if 'parententity' in elem._annotations + and actual_froms.intersection(elem._from_objects) + ]) + if self.entities: + self.entity_zero = self.entities[0] + self.mapper = self.entity_zero.mapper + elif self.namespace is not None: + self.entity_zero = self.namespace + self.mapper = None + else: + self.entity_zero = None + self.mapper = None + + supports_single_entity = False + + @property + def entity_zero_or_selectable(self): + if self.entity_zero is not None: + return self.entity_zero + elif self.actual_froms: + return list(self.actual_froms)[0] + else: + return None + def adapt_to_selectable(self, query, sel): - _ColumnEntity(query, sel.corresponding_column(self.column)) - - def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic): - self.selectable = from_obj - self.froms.add(from_obj) + c = _ColumnEntity(query, sel.corresponding_column(self.column)) + c._label_name = self._label_name + c.entity_zero = self.entity_zero + c.entities = self.entities + + def setup_entity(self, ext_info, aliased_adapter): + if 'selectable' not in self.__dict__: + self.selectable = ext_info.selectable + + if self.actual_froms.intersection(ext_info.selectable._from_objects): + 
self.froms.add(ext_info.selectable) def corresponds_to(self, entity): + # TODO: just returning False here, + # no tests fail if self.entity_zero is None: return False elif _is_aliased_class(entity): + # TODO: polymorphic subclasses ? return entity is self.entity_zero else: return not _is_aliased_class(self.entity_zero) and \ - entity.common_parent(self.entity_zero) + entity.common_parent(self.entity_zero) - def _resolve_expr_against_query_aliases(self, query, expr, context): - return query._adapt_clause(expr, False, True) - - def row_processor(self, query, context, custom_rows): - column = self._resolve_expr_against_query_aliases(query, self.column, context) + def row_processor(self, query, context, result): + if ('fetch_column', self) in context.attributes: + column = context.attributes[('fetch_column', self)] + else: + column = query._adapt_clause(self.column, False, True) if context.adapter: column = context.adapter.columns[column] - def proc(row, result): - return row[column] - - return (proc, self._result_label) + getter = result._getter(column) + return getter, self._label_name def setup_context(self, query, context): - column = self._resolve_expr_against_query_aliases(query, self.column, context) + column = query._adapt_clause(self.column, False, True) context.froms += tuple(self.froms) context.primary_columns.append(column) + context.attributes[('fetch_column', self)] = column + def __str__(self): return str(self.column) -log.class_logger(Query) class QueryContext(object): - multi_row_eager_loaders = False - adapter = None - froms = () - + __slots__ = ( + 'multi_row_eager_loaders', 'adapter', 'froms', 'for_update', + 'query', 'session', 'autoflush', 'populate_existing', + 'invoke_all_eagers', 'version_check', 'refresh_state', + 'primary_columns', 'secondary_columns', 'eager_order_by', + 'eager_joins', 'create_eager_joins', 'propagate_options', + 'attributes', 'statement', 'from_clause', 'whereclause', + 'order_by', 'labels', '_for_update_arg', 'runid', 'partials' + ) + def __init__(self, query): if query._statement is not None: - if isinstance(query._statement, expression._SelectBaseMixin) and not query._statement.use_labels: + if isinstance(query._statement, expression.SelectBase) and \ + not query._statement._textual and \ + not query._statement.use_labels: self.statement = query._statement.apply_labels() else: self.statement = query._statement @@ -2430,9 +4121,15 @@ class QueryContext(object): self.whereclause = query._criterion self.order_by = query._order_by + self.multi_row_eager_loaders = False + self.adapter = None + self.froms = () + self.for_update = None self.query = query self.session = query.session + self.autoflush = query._autoflush self.populate_existing = query._populate_existing + self.invoke_all_eagers = query._invoke_all_eagers self.version_check = query._version_check self.refresh_state = query._refresh_state self.primary_columns = [] @@ -2440,30 +4137,51 @@ class QueryContext(object): self.eager_order_by = [] self.eager_joins = {} self.create_eager_joins = [] - self.propagate_options = set(o for o in query._with_options if o.propagate_to_loaders) + self.propagate_options = set(o for o in query._with_options if + o.propagate_to_loaders) self.attributes = query._attributes.copy() + class AliasOption(interfaces.MapperOption): def __init__(self, alias): + r"""Return a :class:`.MapperOption` that will indicate to the :class:`.Query` + that the main table has been aliased. 
+ + This is a seldom-used option to suit the + very rare case that :func:`.contains_eager` + is being used in conjunction with a user-defined SELECT + statement that aliases the parent table. E.g.:: + + # define an aliased UNION called 'ulist' + ulist = users.select(users.c.user_id==7).\ + union(users.select(users.c.user_id>7)).\ + alias('ulist') + + # add on an eager load of "addresses" + statement = ulist.outerjoin(addresses).\ + select().apply_labels() + + # create query, indicating "ulist" will be an + # alias for the main table, "addresses" + # property should be eager loaded + query = session.query(User).options( + contains_alias(ulist), + contains_eager(User.addresses)) + + # then get results via the statement + results = query.from_statement(statement).all() + + :param alias: is the string name of an alias, or a + :class:`~.sql.expression.Alias` object representing + the alias. + + """ self.alias = alias def process_query(self, query): - if isinstance(self.alias, basestring): + if isinstance(self.alias, util.string_types): alias = query._mapper_zero().mapped_table.alias(self.alias) else: alias = self.alias query._from_obj_alias = sql_util.ColumnAdapter(alias) - - -_runid = 1L -_id_lock = util.threading.Lock() - -def _new_runid(): - global _runid - _id_lock.acquire() - try: - _runid += 1 - return _runid - finally: - _id_lock.release() diff --git a/sqlalchemy/orm/scoping.py b/sqlalchemy/orm/scoping.py index 40bbb32..05b8813 100644 --- a/sqlalchemy/orm/scoping.py +++ b/sqlalchemy/orm/scoping.py @@ -1,96 +1,120 @@ -# scoping.py -# Copyright (C) the SQLAlchemy authors and contributors +# orm/scoping.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sqlalchemy.exceptions as sa_exc -from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \ - to_list, get_cls_kwargs, deprecated -from sqlalchemy.orm import ( - EXT_CONTINUE, MapperExtension, class_mapper, object_session - ) -from sqlalchemy.orm import exc as orm_exc -from sqlalchemy.orm.session import Session +from .. import exc as sa_exc +from ..util import ScopedRegistry, ThreadLocalRegistry, warn +from . import class_mapper, exc as orm_exc +from .session import Session -__all__ = ['ScopedSession'] +__all__ = ['scoped_session'] -class ScopedSession(object): - """Provides thread-local management of Sessions. +class scoped_session(object): + """Provides scoped management of :class:`.Session` objects. - Usage:: - - Session = scoped_session(sessionmaker(autoflush=True)) - - ... use session normally. + See :ref:`unitofwork_contextual` for a tutorial. """ + session_factory = None + """The `session_factory` provided to `__init__` is stored in this + attribute and may be accessed at a later time. This can be useful when + a new non-scoped :class:`.Session` or :class:`.Connection` to the + database is needed.""" + def __init__(self, session_factory, scopefunc=None): + """Construct a new :class:`.scoped_session`. + + :param session_factory: a factory to create new :class:`.Session` + instances. This is usually, but not necessarily, an instance + of :class:`.sessionmaker`. + :param scopefunc: optional function which defines + the current scope. If not passed, the :class:`.scoped_session` + object assumes "thread-local" scope, and will use + a Python ``threading.local()`` in order to maintain the current + :class:`.Session`. 
If passed, the function should return + a hashable token; this token will be used as the key in a + dictionary in order to store and retrieve the current + :class:`.Session`. + + """ self.session_factory = session_factory + if scopefunc: self.registry = ScopedRegistry(session_factory, scopefunc) else: self.registry = ThreadLocalRegistry(session_factory) - self.extension = _ScopedExt(self) - def __call__(self, **kwargs): - if kwargs: - scope = kwargs.pop('scope', False) + def __call__(self, **kw): + r"""Return the current :class:`.Session`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.Session` is not present. If the :class:`.Session` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + scope = kw.pop('scope', False) if scope is not None: if self.registry.has(): - raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.") + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified.") else: - sess = self.session_factory(**kwargs) + sess = self.session_factory(**kw) self.registry.set(sess) return sess else: - return self.session_factory(**kwargs) + return self.session_factory(**kw) else: return self.registry() def remove(self): - """Dispose of the current contextual session.""" - + """Dispose of the current :class:`.Session`, if present. + + This will first call :meth:`.Session.close` method + on the current :class:`.Session`, which releases any existing + transactional/connection resources still being held; transactions + specifically are rolled back. The :class:`.Session` is then + discarded. Upon next usage within the same scope, + the :class:`.scoped_session` will produce a new + :class:`.Session` object. + + """ + if self.registry.has(): self.registry().close() self.registry.clear() - @deprecated("Session.mapper is deprecated. " - "Please see http://www.sqlalchemy.org/trac/wiki/UsageRecipes/SessionAwareMapper " - "for information on how to replicate its behavior.") - def mapper(self, *args, **kwargs): - """return a mapper() function which associates this ScopedSession with the Mapper. + def configure(self, **kwargs): + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. - DEPRECATED. + See :meth:`.sessionmaker.configure`. """ - from sqlalchemy.orm import mapper - - extension_args = dict((arg, kwargs.pop(arg)) - for arg in get_cls_kwargs(_ScopedExt) - if arg in kwargs) - - kwargs['extension'] = extension = to_list(kwargs.get('extension', [])) - if extension_args: - extension.append(self.extension.configure(**extension_args)) - else: - extension.append(self.extension) - return mapper(*args, **kwargs) - - def configure(self, **kwargs): - """reconfigure the sessionmaker used by this ScopedSession.""" + if self.registry.has(): + warn('At least one scoped session is already present. ' + ' configure() can not affect sessions that have ' + 'already been created.') self.session_factory.configure(**kwargs) def query_property(self, query_cls=None): - """return a class property which produces a `Query` object against the - class when called. + """return a class property which produces a :class:`.Query` object + against the class and the current :class:`.Session` when called. 
e.g.:: + Session = scoped_session(sessionmaker()) class MyClass(object): @@ -124,82 +148,37 @@ class ScopedSession(object): return None return query() +ScopedSession = scoped_session +"""Old name for backwards compatibility.""" + + def instrument(name): def do(self, *args, **kwargs): return getattr(self.registry(), name)(*args, **kwargs) return do + for meth in Session.public_methods: - setattr(ScopedSession, meth, instrument(meth)) + setattr(scoped_session, meth, instrument(meth)) + def makeprop(name): def set(self, attr): setattr(self.registry(), name, attr) + def get(self): return getattr(self.registry(), name) + return property(get, set) -for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'): - setattr(ScopedSession, prop, makeprop(prop)) + +for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', + 'is_active', 'autoflush', 'no_autoflush', 'info'): + setattr(scoped_session, prop, makeprop(prop)) + def clslevel(name): def do(cls, *args, **kwargs): return getattr(Session, name)(*args, **kwargs) return classmethod(do) + for prop in ('close_all', 'object_session', 'identity_key'): - setattr(ScopedSession, prop, clslevel(prop)) - -class _ScopedExt(MapperExtension): - def __init__(self, context, validate=False, save_on_init=True): - self.context = context - self.validate = validate - self.save_on_init = save_on_init - self.set_kwargs_on_init = True - - def validating(self): - return _ScopedExt(self.context, validate=True) - - def configure(self, **kwargs): - return _ScopedExt(self.context, **kwargs) - - def instrument_class(self, mapper, class_): - class query(object): - def __getattr__(s, key): - return getattr(self.context.registry().query(class_), key) - def __call__(s): - return self.context.registry().query(class_) - def __get__(self, instance, cls): - return self - - if not 'query' in class_.__dict__: - class_.query = query() - - if self.set_kwargs_on_init and class_.__init__ is object.__init__: - class_.__init__ = self._default__init__(mapper) - - def _default__init__(ext, mapper): - def __init__(self, **kwargs): - for key, value in kwargs.iteritems(): - if ext.validate: - if not mapper.get_property(key, resolve_synonyms=False, - raiseerr=False): - raise sa_exc.ArgumentError( - "Invalid __init__ argument: '%s'" % key) - setattr(self, key, value) - return __init__ - - def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): - if self.save_on_init: - session = kwargs.pop('_sa_session', None) - if session is None: - session = self.context.registry() - session._save_without_cascade(instance) - return EXT_CONTINUE - - def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): - sess = object_session(instance) - if sess: - sess.expunge(instance) - return EXT_CONTINUE - - def dispose_class(self, mapper, class_): - if hasattr(class_, 'query'): - delattr(class_, 'query') + setattr(scoped_session, prop, clslevel(prop)) diff --git a/sqlalchemy/orm/session.py b/sqlalchemy/orm/session.py index 0a3fbe7..0819204 100644 --- a/sqlalchemy/orm/session.py +++ b/sqlalchemy/orm/session.py @@ -1,223 +1,214 @@ -# session.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# orm/session.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php - """Provides the Session class and related utilities.""" + import weakref -from itertools import chain 
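# Usage sketch for the scoped_session registry documented above: the same
# Session is returned per scope until remove() discards it.  The
# in-memory engine is an assumption made for illustration only.

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

Session = scoped_session(sessionmaker(bind=create_engine('sqlite://')))

a = Session()
b = Session()
assert a is b              # same (thread-local) scope -> same Session

Session.remove()           # closes and discards the current Session
assert Session() is not a  # the next call produces a fresh one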
-import sqlalchemy.exceptions as sa_exc -from sqlalchemy import util, sql, engine, log -from sqlalchemy.sql import util as sql_util, expression -from sqlalchemy.orm import ( - SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, state - ) -from sqlalchemy.orm.util import object_mapper as _object_mapper -from sqlalchemy.orm.util import class_mapper as _class_mapper -from sqlalchemy.orm.util import ( - _class_to_mapper, _state_has_identity, _state_mapper, - ) -from sqlalchemy.orm.mapper import Mapper, _none_set -from sqlalchemy.orm.unitofwork import UOWTransaction -from sqlalchemy.orm import identity +from .. import util, sql, engine, exc as sa_exc +from ..sql import util as sql_util, expression +from . import ( + SessionExtension, attributes, exc, query, + loading, identity +) +from ..inspection import inspect +from .base import ( + object_mapper, class_mapper, + _class_to_mapper, _state_mapper, object_state, + _none_set, state_str, instance_str +) +import itertools +from . import persistence +from .unitofwork import UOWTransaction +from . import state as statelib +import sys -__all__ = ['Session', 'SessionTransaction', 'SessionExtension'] +__all__ = ['Session', 'SessionTransaction', + 'SessionExtension', 'sessionmaker'] + +_sessions = weakref.WeakValueDictionary() +"""Weak-referencing dictionary of :class:`.Session` objects. +""" -def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, - expire_on_commit=True, **kwargs): - """Generate a custom-configured :class:`~sqlalchemy.orm.session.Session` class. - - The returned object is a subclass of ``Session``, which, when instantiated - with no arguments, uses the keyword arguments configured here as its - constructor arguments. - - It is intended that the `sessionmaker()` function be called within the - global scope of an application, and the returned class be made available - to the rest of the application as the single class used to instantiate - sessions. - - e.g.:: - - # global scope - Session = sessionmaker(autoflush=False) - - # later, in a local scope, create and use a session: - sess = Session() - - Any keyword arguments sent to the constructor itself will override the - "configured" keywords:: - - Session = sessionmaker() - - # bind an individual session to a connection - sess = Session(bind=connection) - - The class also includes a special classmethod ``configure()``, which - allows additional configurational options to take place after the custom - ``Session`` class has been generated. This is useful particularly for - defining the specific ``Engine`` (or engines) to which new instances of - ``Session`` should be bound:: - - Session = sessionmaker() - Session.configure(bind=create_engine('sqlite:///foo.db')) - - sess = Session() - - Options: - - autocommit - Defaults to ``False``. When ``True``, the ``Session`` does not keep a - persistent transaction running, and will acquire connections from the - engine on an as-needed basis, returning them immediately after their - use. Flushes will begin and commit (or possibly rollback) their own - transaction if no transaction is present. When using this mode, the - `session.begin()` method may be used to begin a transaction explicitly. - - Leaving it on its default value of ``False`` means that the ``Session`` - will acquire a connection and begin a transaction the first time it is - used, which it will maintain persistently until ``rollback()``, - ``commit()``, or ``close()`` is called. 
When the transaction is released - by any of these methods, the ``Session`` is ready for the next usage, - which will again acquire and maintain a new connection/transaction. - - autoflush - When ``True``, all query operations will issue a ``flush()`` call to - this ``Session`` before proceeding. This is a convenience feature so - that ``flush()`` need not be called repeatedly in order for database - queries to retrieve results. It's typical that ``autoflush`` is used in - conjunction with ``autocommit=False``. In this scenario, explicit calls - to ``flush()`` are rarely needed; you usually only need to call - ``commit()`` (which flushes) to finalize changes. - - bind - An optional ``Engine`` or ``Connection`` to which this ``Session`` - should be bound. When specified, all SQL operations performed by this - session will execute via this connectable. - - binds - An optional dictionary, which contains more granular "bind" information - than the ``bind`` parameter provides. This dictionary can map individual - ``Table`` instances as well as ``Mapper`` instances to individual - ``Engine`` or ``Connection`` objects. Operations which proceed relative - to a particular ``Mapper`` will consult this dictionary for the direct - ``Mapper`` instance as well as the mapper's ``mapped_table`` attribute - in order to locate an connectable to use. The full resolution is - described in the ``get_bind()`` method of ``Session``. Usage looks - like:: - - sess = Session(binds={ - SomeMappedClass: create_engine('postgresql://engine1'), - somemapper: create_engine('postgresql://engine2'), - some_table: create_engine('postgresql://engine3'), - }) - - Also see the ``bind_mapper()`` and ``bind_table()`` methods. - - \class_ - Specify an alternate class other than ``sqlalchemy.orm.session.Session`` - which should be used by the returned class. This is the only argument - that is local to the ``sessionmaker()`` function, and is not sent - directly to the constructor for ``Session``. - - _enable_transaction_accounting - Defaults to ``True``. A legacy-only flag which when ``False`` - disables *all* 0.5-style object accounting on transaction boundaries, - including auto-expiry of instances on rollback and commit, maintenance of - the "new" and "deleted" lists upon rollback, and autoflush - of pending changes upon begin(), all of which are interdependent. - - expire_on_commit - Defaults to ``True``. When ``True``, all instances will be fully expired after - each ``commit()``, so that all attribute/object access subsequent to a completed - transaction will load from the most recent database state. - - extension - An optional :class:`~sqlalchemy.orm.session.SessionExtension` instance, or - a list of such instances, which - will receive pre- and post- commit and flush events, as well as a - post-rollback event. User- defined code may be placed within these - hooks using a user-defined subclass of ``SessionExtension``. - - query_cls - Class which should be used to create new Query objects, as returned - by the ``query()`` method. Defaults to :class:`~sqlalchemy.orm.query.Query`. - - twophase - When ``True``, all transactions will be started using - :mod:~sqlalchemy.engine_TwoPhaseTransaction. During a ``commit()``, after - ``flush()`` has been issued for all attached databases, the - ``prepare()`` method on each database's ``TwoPhaseTransaction`` will be - called. This allows each database to roll back the entire transaction, - before each transaction is committed. 
- - weak_identity_map - When set to the default value of ``True``, a weak-referencing map is - used; instances which are not externally referenced will be garbage - collected immediately. For dereferenced instances which have pending - changes present, the attribute management system will create a temporary - strong-reference to the object which lasts until the changes are flushed - to the database, at which point it's again dereferenced. Alternatively, - when using the value ``False``, the identity map uses a regular Python - dictionary to store instances. The session will maintain all instances - present until they are removed using expunge(), clear(), or purge(). - +def _state_session(state): + """Given an :class:`.InstanceState`, return the :class:`.Session` + associated, if any. """ - kwargs['bind'] = bind - kwargs['autoflush'] = autoflush - kwargs['autocommit'] = autocommit - kwargs['expire_on_commit'] = expire_on_commit + if state.session_id: + try: + return _sessions[state.session_id] + except KeyError: + pass + return None - if class_ is None: - class_ = Session - class Sess(object): - def __init__(self, **local_kwargs): - for k in kwargs: - local_kwargs.setdefault(k, kwargs[k]) - super(Sess, self).__init__(**local_kwargs) +class _SessionClassMethods(object): + """Class-level methods for :class:`.Session`, :class:`.sessionmaker`.""" - def configure(self, **new_kwargs): - """(Re)configure the arguments for this sessionmaker. + @classmethod + def close_all(cls): + """Close *all* sessions in memory.""" - e.g.:: + for sess in _sessions.values(): + sess.close() - Session = sessionmaker() + @classmethod + @util.dependencies("sqlalchemy.orm.util") + def identity_key(cls, orm_util, *args, **kwargs): + """Return an identity key. - Session.configure(bind=create_engine('sqlite://')) - """ - kwargs.update(new_kwargs) - configure = classmethod(configure) - s = type.__new__(type, "Session", (Sess, class_), {}) - return s + This is an alias of :func:`.util.identity_key`. + + """ + return orm_util.identity_key(*args, **kwargs) + + @classmethod + def object_session(cls, instance): + """Return the :class:`.Session` to which an object belongs. + + This is an alias of :func:`.object_session`. + + """ + + return object_session(instance) + + +ACTIVE = util.symbol('ACTIVE') +PREPARED = util.symbol('PREPARED') +COMMITTED = util.symbol('COMMITTED') +DEACTIVE = util.symbol('DEACTIVE') +CLOSED = util.symbol('CLOSED') class SessionTransaction(object): - """A Session-level transaction. + """A :class:`.Session`-level transaction. - This corresponds to one or more :class:`~sqlalchemy.engine.Transaction` - instances behind the scenes, with one ``Transaction`` per ``Engine`` in - use. + :class:`.SessionTransaction` is a mostly behind-the-scenes object + not normally referenced directly by application code. It coordinates + among multiple :class:`.Connection` objects, maintaining a database + transaction for each one individually, committing or rolling them + back all at once. It also provides optional two-phase commit behavior + which can augment this coordination operation. - Direct usage of ``SessionTransaction`` is not necessary as of SQLAlchemy - 0.4; use the ``begin()`` and ``commit()`` methods on ``Session`` itself. + The :attr:`.Session.transaction` attribute of :class:`.Session` + refers to the current :class:`.SessionTransaction` object in use, if any. 
+ The :attr:`.SessionTransaction.parent` attribute refers to the parent + :class:`.SessionTransaction` in the stack of :class:`.SessionTransaction` + objects. If this attribute is ``None``, then this is the top of the stack. + If non-``None``, then this :class:`.SessionTransaction` refers either + to a so-called "subtransaction" or a "nested" transaction. A + "subtransaction" is a scoping concept that demarcates an inner portion + of the outermost "real" transaction. A nested transaction, which + is indicated when the :attr:`.SessionTransaction.nested` + attribute is also True, indicates that this :class:`.SessionTransaction` + corresponds to a SAVEPOINT. - The ``SessionTransaction`` object is **not** thread-safe. + **Life Cycle** - .. index:: - single: thread safety; SessionTransaction + A :class:`.SessionTransaction` is associated with a :class:`.Session` + in its default mode of ``autocommit=False`` immediately, associated + with no database connections. As the :class:`.Session` is called upon + to emit SQL on behalf of various :class:`.Engine` or :class:`.Connection` + objects, a corresponding :class:`.Connection` and associated + :class:`.Transaction` is added to a collection within the + :class:`.SessionTransaction` object, becoming one of the + connection/transaction pairs maintained by the + :class:`.SessionTransaction`. The start of a :class:`.SessionTransaction` + can be tracked using the :meth:`.SessionEvents.after_transaction_create` + event. + + The lifespan of the :class:`.SessionTransaction` ends when the + :meth:`.Session.commit`, :meth:`.Session.rollback` or + :meth:`.Session.close` methods are called. At this point, the + :class:`.SessionTransaction` removes its association with its parent + :class:`.Session`. A :class:`.Session` that is in ``autocommit=False`` + mode will create a new :class:`.SessionTransaction` to replace it + immediately, whereas a :class:`.Session` that's in ``autocommit=True`` + mode will remain without a :class:`.SessionTransaction` until the + :meth:`.Session.begin` method is called. The end of a + :class:`.SessionTransaction` can be tracked using the + :meth:`.SessionEvents.after_transaction_end` event. + + **Nesting and Subtransactions** + + Another detail of :class:`.SessionTransaction` behavior is that it is + capable of "nesting". This means that the :meth:`.Session.begin` method + can be called while an existing :class:`.SessionTransaction` is already + present, producing a new :class:`.SessionTransaction` that temporarily + replaces the parent :class:`.SessionTransaction`. When a + :class:`.SessionTransaction` is produced as nested, it assigns itself to + the :attr:`.Session.transaction` attribute, and it additionally will assign + the previous :class:`.SessionTransaction` to its :attr:`.Session.parent` + attribute. The behavior is effectively a + stack, where :attr:`.Session.transaction` refers to the current head of + the stack, and the :attr:`.SessionTransaction.parent` attribute allows + traversal up the stack until :attr:`.SessionTransaction.parent` is + ``None``, indicating the top of the stack. + + When the scope of :class:`.SessionTransaction` is ended via + :meth:`.Session.commit` or :meth:`.Session.rollback`, it restores its + parent :class:`.SessionTransaction` back onto the + :attr:`.Session.transaction` attribute. + + The purpose of this stack is to allow nesting of + :meth:`.Session.rollback` or :meth:`.Session.commit` calls in context + with various flavors of :meth:`.Session.begin`. 
This nesting behavior
+    applies to when :meth:`.Session.begin_nested` is used to emit a
+    SAVEPOINT transaction, and is also used to produce a so-called
+    "subtransaction" which allows a block of code to use a
+    begin/rollback/commit sequence regardless of whether or not its
+    enclosing code block has begun a transaction.  The :meth:`.flush`
+    method, whether called explicitly or via autoflush, is the primary
+    consumer of the "subtransaction" feature, in that it wishes to
+    guarantee that it works within a transaction block regardless of
+    whether or not the :class:`.Session` is in transactional mode when
+    the method is called.
+
+    Note that the flush process that occurs within the "autoflush"
+    feature as well as when the :meth:`.Session.flush` method is used
+    **always** creates a :class:`.SessionTransaction` object.  This
+    object is normally a subtransaction, unless the :class:`.Session`
+    is in autocommit mode and no transaction exists at all, in which
+    case it's the outermost transaction.  Any event-handling logic or
+    other inspection logic needs to take into account whether a
+    :class:`.SessionTransaction` is the outermost transaction, a
+    subtransaction, or a "nested" / SAVEPOINT transaction.
+
+    .. seealso::
+
+        :meth:`.Session.rollback`
+
+        :meth:`.Session.commit`
+
+        :meth:`.Session.begin`
+
+        :meth:`.Session.begin_nested`
+
+        :attr:`.Session.is_active`
+
+        :meth:`.SessionEvents.after_transaction_create`
+
+        :meth:`.SessionEvents.after_transaction_end`
+
+        :meth:`.SessionEvents.after_commit`
+
+        :meth:`.SessionEvents.after_rollback`
+
+        :meth:`.SessionEvents.after_soft_rollback`

    """

+    _rollback_exception = None
+
    def __init__(self, session, parent=None, nested=False):
        self.session = session
        self._connections = {}
        self._parent = parent
        self.nested = nested
-        self._active = True
-        self._prepared = False
+        self._state = ACTIVE
        if not parent and nested:
            raise sa_exc.InvalidRequestError(
                "Can't start a SAVEPOINT transaction when no existing "
@@ -226,49 +217,110 @@ class SessionTransaction(object):
        if self.session._enable_transaction_accounting:
            self._take_snapshot()

+        self.session.dispatch.after_transaction_create(self.session, self)
+
+    @property
+    def parent(self):
+        """The parent :class:`.SessionTransaction` of this
+        :class:`.SessionTransaction`.
+
+        If this attribute is ``None``, indicates this
+        :class:`.SessionTransaction` is at the top of the stack, and
+        corresponds to a real "COMMIT"/"ROLLBACK"
+        block.  If non-``None``, then this is either a "subtransaction"
+        or a "nested" / SAVEPOINT transaction.  If the
+        :attr:`.SessionTransaction.nested` attribute is ``True``, then
+        this is a SAVEPOINT, and if ``False``, indicates this is a
+        subtransaction.
+
+        .. versionadded:: 1.0.16 - use ._parent for previous versions
+
+        """
+        return self._parent
+
+    nested = False
+    """Indicates if this is a nested, or SAVEPOINT, transaction.
+
+    When :attr:`.SessionTransaction.nested` is True, it is expected
+    that :attr:`.SessionTransaction.parent` will also be non-``None``.
+
+    """
+
    @property
    def is_active(self):
-        return self.session is not None and self._active
+        return self.session is not None and self._state is ACTIVE

-    def _assert_is_active(self):
-        self._assert_is_open()
-        if not self._active:
+    def _assert_active(self, prepared_ok=False,
+                       rollback_ok=False,
+                       deactive_ok=False,
+                       closed_msg="This transaction is closed"):
+        if self._state is COMMITTED:
            raise sa_exc.InvalidRequestError(
-                "The transaction is inactive due to a rollback in a "
-                "subtransaction.
Issue rollback() to cancel the transaction.") - - def _assert_is_open(self, error_msg="The transaction is closed"): - if self.session is None: - raise sa_exc.InvalidRequestError(error_msg) + "This session is in 'committed' state; no further " + "SQL can be emitted within this transaction." + ) + elif self._state is PREPARED: + if not prepared_ok: + raise sa_exc.InvalidRequestError( + "This session is in 'prepared' state; no further " + "SQL can be emitted within this transaction." + ) + elif self._state is DEACTIVE: + if not deactive_ok and not rollback_ok: + if self._rollback_exception: + raise sa_exc.InvalidRequestError( + "This Session's transaction has been rolled back " + "due to a previous exception during flush." + " To begin a new transaction with this Session, " + "first issue Session.rollback()." + " Original exception was: %s" + % self._rollback_exception + ) + elif not deactive_ok: + raise sa_exc.InvalidRequestError( + "This Session's transaction has been rolled back " + "by a nested rollback() call. To begin a new " + "transaction, issue Session.rollback() first." + ) + elif self._state is CLOSED: + raise sa_exc.ResourceClosedError(closed_msg) @property def _is_transaction_boundary(self): return self.nested or not self._parent - def connection(self, bindkey, **kwargs): - self._assert_is_active() - engine = self.session.get_bind(bindkey, **kwargs) - return self._connection_for_bind(engine) + def connection(self, bindkey, execution_options=None, **kwargs): + self._assert_active() + bind = self.session.get_bind(bindkey, **kwargs) + return self._connection_for_bind(bind, execution_options) def _begin(self, nested=False): - self._assert_is_active() + self._assert_active() return SessionTransaction( self.session, self, nested=nested) - def _iterate_parents(self, upto=None): - if self._parent is upto: - return (self,) - else: - if self._parent is None: + def _iterate_self_and_parents(self, upto=None): + + current = self + result = () + while current: + result += (current, ) + if current._parent is upto: + break + elif current._parent is None: raise sa_exc.InvalidRequestError( "Transaction %s is not on the active transaction list" % ( - upto)) - return (self,) + self._parent._iterate_parents(upto) + upto)) + else: + current = current._parent + + return result def _take_snapshot(self): if not self._is_transaction_boundary: self._new = self._parent._new self._deleted = self._parent._deleted + self._dirty = self._parent._dirty + self._key_switches = self._parent._key_switches return if not self.session._flushing: @@ -276,36 +328,68 @@ class SessionTransaction(object): self._new = weakref.WeakKeyDictionary() self._deleted = weakref.WeakKeyDictionary() + self._dirty = weakref.WeakKeyDictionary() + self._key_switches = weakref.WeakKeyDictionary() - def _restore_snapshot(self): + def _restore_snapshot(self, dirty_only=False): + """Restore the restoration state taken before a transaction began. + + Corresponds to a rollback. 
+ + """ assert self._is_transaction_boundary - for s in set(self._new).union(self.session._new): - self.session._expunge_state(s) + self.session._expunge_states( + set(self._new).union(self.session._new), + to_transient=True) + + for s, (oldkey, newkey) in self._key_switches.items(): + self.session.identity_map.safe_discard(s) + s.key = oldkey + self.session.identity_map.replace(s) for s in set(self._deleted).union(self.session._deleted): - self.session._update_impl(s) + self.session._update_impl(s, revert_deletion=True) assert not self.session._deleted for s in self.session.identity_map.all_states(): - _expire_state(s, s.dict, None, instance_dict=self.session.identity_map) + if not dirty_only or s.modified or s in self._dirty: + s._expire(s.dict, self.session.identity_map._modified) def _remove_snapshot(self): + """Remove the restoration state taken before a transaction began. + + Corresponds to a commit. + + """ assert self._is_transaction_boundary if not self.nested and self.session.expire_on_commit: for s in self.session.identity_map.all_states(): - _expire_state(s, s.dict, None, instance_dict=self.session.identity_map) + s._expire(s.dict, self.session.identity_map._modified) - def _connection_for_bind(self, bind): - self._assert_is_active() + statelib.InstanceState._detach_states( + list(self._deleted), self.session) + self._deleted.clear() + elif self.nested: + self._parent._new.update(self._new) + self._parent._dirty.update(self._dirty) + self._parent._deleted.update(self._deleted) + self._parent._key_switches.update(self._key_switches) + + def _connection_for_bind(self, bind, execution_options): + self._assert_active() if bind in self._connections: + if execution_options: + util.warn( + "Connection is already established for the " + "given bind; execution_options ignored") return self._connections[bind][0] if self._parent: - conn = self._parent._connection_for_bind(bind) + conn = self._parent._connection_for_bind(bind, execution_options) if not self.nested: return conn else: @@ -318,6 +402,9 @@ class SessionTransaction(object): else: conn = bind.contextual_connect() + if execution_options: + conn = conn.execution_options(**execution_options) + if self.session.twophase and self._parent is None: transaction = conn.begin_twophase() elif self.nested: @@ -326,53 +413,59 @@ class SessionTransaction(object): transaction = conn.begin() self._connections[conn] = self._connections[conn.engine] = \ - (conn, transaction, conn is not bind) - for ext in self.session.extensions: - ext.after_begin(self.session, self, conn) + (conn, transaction, conn is not bind) + self.session.dispatch.after_begin(self.session, self, conn) return conn def prepare(self): if self._parent is not None or not self.session.twophase: raise sa_exc.InvalidRequestError( - "Only root two phase transactions of can be prepared") + "'twophase' mode not enabled, or not root transaction; " + "can't prepare.") self._prepare_impl() def _prepare_impl(self): - self._assert_is_active() + self._assert_active() if self._parent is None or self.nested: - for ext in self.session.extensions: - ext.before_commit(self.session) + self.session.dispatch.before_commit(self.session) stx = self.session.transaction if stx is not self: - for subtransaction in stx._iterate_parents(upto=self): + for subtransaction in stx._iterate_self_and_parents(upto=self): subtransaction.commit() if not self.session._flushing: - self.session.flush() + for _flush_guard in range(100): + if self.session._is_clean(): + break + self.session.flush() + else: + raise 
exc.FlushError( + "Over 100 subsequent flushes have occurred within " + "session.commit() - is an after_flush() hook " + "creating new objects?") if self._parent is None and self.session.twophase: try: for t in set(self._connections.values()): t[1].prepare() except: - self.rollback() - raise + with util.safe_reraise(): + self.rollback() - self._deactivate() - self._prepared = True + self._state = PREPARED def commit(self): - self._assert_is_open() - if not self._prepared: + self._assert_active(prepared_ok=True) + if self._state is not PREPARED: self._prepare_impl() if self._parent is None or self.nested: for t in set(self._connections.values()): t[1].commit() - for ext in self.session.extensions: - ext.after_commit(self.session) + self._state = COMMITTED + self.session.dispatch.after_commit(self.session) if self.session._enable_transaction_accounting: self._remove_snapshot() @@ -380,50 +473,82 @@ class SessionTransaction(object): self.close() return self._parent - def rollback(self): - self._assert_is_open() + def rollback(self, _capture_exception=False): + self._assert_active(prepared_ok=True, rollback_ok=True) stx = self.session.transaction if stx is not self: - for subtransaction in stx._iterate_parents(upto=self): + for subtransaction in stx._iterate_self_and_parents(upto=self): subtransaction.close() - if self.is_active or self._prepared: - for transaction in self._iterate_parents(): + boundary = self + rollback_err = None + if self._state in (ACTIVE, PREPARED): + for transaction in self._iterate_self_and_parents(): if transaction._parent is None or transaction.nested: - transaction._rollback_impl() - transaction._deactivate() + try: + transaction._rollback_impl() + except: + rollback_err = sys.exc_info() + transaction._state = DEACTIVE + boundary = transaction break else: - transaction._deactivate() + transaction._state = DEACTIVE + + sess = self.session + + if not rollback_err and sess._enable_transaction_accounting and \ + not sess._is_clean(): + + # if items were added, deleted, or mutated + # here, we need to re-restore the snapshot + util.warn( + "Session's state has been changed on " + "a non-active transaction - this state " + "will be discarded.") + boundary._restore_snapshot(dirty_only=boundary.nested) self.close() + + if self._parent and _capture_exception: + self._parent._rollback_exception = sys.exc_info()[1] + + if rollback_err: + util.reraise(*rollback_err) + + sess.dispatch.after_soft_rollback(sess, self) + return self._parent def _rollback_impl(self): - for t in set(self._connections.values()): - t[1].rollback() + try: + for t in set(self._connections.values()): + t[1].rollback() + finally: + if self.session._enable_transaction_accounting: + self._restore_snapshot(dirty_only=self.nested) - if self.session._enable_transaction_accounting: - self._restore_snapshot() + self.session.dispatch.after_rollback(self.session) - for ext in self.session.extensions: - ext.after_rollback(self.session) - - def _deactivate(self): - self._active = False - - def close(self): + def close(self, invalidate=False): self.session.transaction = self._parent if self._parent is None: - for connection, transaction, autoclose in set(self._connections.values()): + for connection, transaction, autoclose in \ + set(self._connections.values()): + if invalidate: + connection.invalidate() if autoclose: connection.close() else: transaction.close() + + self._state = CLOSED + self.session.dispatch.after_transaction_end(self.session, self) + + if self._parent is None: if not self.session.autocommit: 
self.session.begin() - self._deactivate() self.session = None self._connections = None @@ -431,109 +556,173 @@ class SessionTransaction(object): return self def __exit__(self, type, value, traceback): - self._assert_is_open("Cannot end transaction context. The transaction was closed from within the context") + self._assert_active(deactive_ok=True, prepared_ok=True) if self.session.transaction is None: return if type is None: try: self.commit() except: - self.rollback() - raise + with util.safe_reraise(): + self.rollback() else: self.rollback() -class Session(object): + +class Session(_SessionClassMethods): """Manages persistence operations for ORM-mapped objects. - The Session is the front end to SQLAlchemy's **Unit of Work** - implementation. The concept behind Unit of Work is to track modifications - to a field of objects, and then be able to flush those changes to the - database in a single operation. + The Session's usage paradigm is described at :doc:`/orm/session`. - SQLAlchemy's unit of work includes these functions: - - * The ability to track in-memory changes on scalar- and collection-based - object attributes, such that database persistence operations can be - assembled based on those changes. - - * The ability to organize individual SQL queries and population of newly - generated primary and foreign key-holding attributes during a persist - operation such that referential integrity is maintained at all times. - - * The ability to maintain insert ordering against the order in which new - instances were added to the session. - - * An Identity Map, which is a dictionary keying instances to their unique - primary key identity. This ensures that only one copy of a particular - entity is ever present within the session, even if repeated load - operations for the same entity occur. This allows many parts of an - application to get a handle to a particular object without any chance of - modifications going to two different places. - - When dealing with instances of mapped classes, an instance may be - *attached* to a particular Session, else it is *unattached* . An instance - also may or may not correspond to an actual row in the database. These - conditions break up into four distinct states: - - * *Transient* - an instance that's not in a session, and is not saved to - the database; i.e. it has no database identity. The only relationship - such an object has to the ORM is that its class has a ``mapper()`` - associated with it. - - * *Pending* - when you ``add()`` a transient instance, it becomes - pending. It still wasn't actually flushed to the database yet, but it - will be when the next flush occurs. - - * *Persistent* - An instance which is present in the session and has a - record in the database. You get persistent instances by either flushing - so that the pending instances become persistent, or by querying the - database for existing instances (or moving persistent instances from - other sessions into your local session). - - * *Detached* - an instance which has a record in the database, but is not - in any session. Theres nothing wrong with this, and you can use objects - normally when they're detached, **except** they will not be able to - issue any SQL in order to load collections or attributes which are not - yet loaded, or were marked as "expired". - - The session methods which control instance state include ``add()``, - ``delete()``, ``merge()``, and ``expunge()``. - - The Session object is generally **not** threadsafe. 
A session which is - set to ``autocommit`` and is only read from may be used by concurrent - threads if it's acceptable that some object instances may be loaded twice. - - The typical pattern to managing Sessions in a multi-threaded environment - is either to use mutexes to limit concurrent access to one thread at a - time, or more commonly to establish a unique session for every thread, - using a threadlocal variable. SQLAlchemy provides a thread-managed - Session adapter, provided by the :func:`~sqlalchemy.orm.scoped_session` - function. """ public_methods = ( '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested', 'close', 'commit', 'connection', 'delete', 'execute', 'expire', - 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind', 'is_modified', - 'merge', 'query', 'refresh', 'rollback', + 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind', + 'is_modified', 'bulk_save_objects', 'bulk_insert_mappings', + 'bulk_update_mappings', + 'merge', 'query', 'refresh', 'rollback', 'scalar') def __init__(self, bind=None, autoflush=True, expire_on_commit=True, - _enable_transaction_accounting=True, - autocommit=False, twophase=False, - weak_identity_map=True, binds=None, extension=None, query_cls=query.Query): - """Construct a new Session. + _enable_transaction_accounting=True, + autocommit=False, twophase=False, + weak_identity_map=True, binds=None, extension=None, + info=None, + query_cls=query.Query): + r"""Construct a new Session. - Arguments to ``Session`` are described using the - :func:`~sqlalchemy.orm.sessionmaker` function. + See also the :class:`.sessionmaker` function which is used to + generate a :class:`.Session`-producing callable with a given + set of arguments. + + :param autocommit: + + .. warning:: + + The autocommit flag is **not for general use**, and if it is + used, queries should only be invoked within the span of a + :meth:`.Session.begin` / :meth:`.Session.commit` pair. Executing + queries outside of a demarcated transaction is a legacy mode + of usage, and can in some cases lead to concurrent connection + checkouts. + + Defaults to ``False``. When ``True``, the + :class:`.Session` does not keep a persistent transaction running, + and will acquire connections from the engine on an as-needed basis, + returning them immediately after their use. Flushes will begin and + commit (or possibly rollback) their own transaction if no + transaction is present. When using this mode, the + :meth:`.Session.begin` method is used to explicitly start + transactions. + + .. seealso:: + + :ref:`session_autocommit` + + :param autoflush: When ``True``, all query operations will issue a + :meth:`~.Session.flush` call to this ``Session`` before proceeding. + This is a convenience feature so that :meth:`~.Session.flush` need + not be called repeatedly in order for database queries to retrieve + results. It's typical that ``autoflush`` is used in conjunction + with ``autocommit=False``. In this scenario, explicit calls to + :meth:`~.Session.flush` are rarely needed; you usually only need to + call :meth:`~.Session.commit` (which flushes) to finalize changes. + + :param bind: An optional :class:`.Engine` or :class:`.Connection` to + which this ``Session`` should be bound. When specified, all SQL + operations performed by this session will execute via this + connectable. + + :param binds: An optional dictionary which contains more granular + "bind" information than the ``bind`` parameter provides. 
This
+          dictionary can map individual :class:`.Table`
+          instances as well as :class:`~.Mapper` instances to individual
+          :class:`.Engine` or :class:`.Connection` objects.  Operations which
+          proceed relative to a particular :class:`.Mapper` will consult this
+          dictionary for the direct :class:`.Mapper` instance as
+          well as the mapper's ``mapped_table`` attribute in order to locate
+          a connectable to use.  The full resolution is described in
+          :meth:`.Session.get_bind`.
+          Usage looks like::
+
+            Session = sessionmaker(binds={
+                SomeMappedClass: create_engine('postgresql://engine1'),
+                somemapper: create_engine('postgresql://engine2'),
+                some_table: create_engine('postgresql://engine3'),
+                })
+
+          Also see the :meth:`.Session.bind_mapper`
+          and :meth:`.Session.bind_table` methods.
+
+        :param \class_: Specify an alternate class other than
+           ``sqlalchemy.orm.session.Session`` which should be used by the
+           returned class.  This is the only argument that is local to the
+           :class:`.sessionmaker` function, and is not sent directly to the
+           constructor for ``Session``.
+
+        :param _enable_transaction_accounting:  Defaults to ``True``.  A
+           legacy-only flag which when ``False`` disables *all* 0.5-style
+           object accounting on transaction boundaries, including auto-expiry
+           of instances on rollback and commit, maintenance of the "new" and
+           "deleted" lists upon rollback, and autoflush of pending changes
+           upon :meth:`~.Session.begin`, all of which are interdependent.
+
+        :param expire_on_commit:  Defaults to ``True``. When ``True``, all
+           instances will be fully expired after each :meth:`~.commit`,
+           so that all attribute/object access subsequent to a completed
+           transaction will load from the most recent database state.
+
+        :param extension: An optional
+           :class:`~.SessionExtension` instance, or a list
+           of such instances, which will receive pre- and post- commit and
+           flush events, as well as a post-rollback event. **Deprecated.**
+           Please see :class:`.SessionEvents`.
+
+        :param info: optional dictionary of arbitrary data to be associated
+           with this :class:`.Session`.  Is available via the
+           :attr:`.Session.info` attribute.  Note the dictionary is copied at
+           construction time so that modifications to the per-
+           :class:`.Session` dictionary will be local to that
+           :class:`.Session`.
+
+           .. versionadded:: 0.9.0
+
+        :param query_cls:  Class which should be used to create new Query
+           objects, as returned by the :meth:`~.Session.query` method.
+           Defaults to :class:`.Query`.
+
+        :param twophase:  When ``True``, all transactions will be started as
+           a "two phase" transaction, i.e. using the "two phase" semantics
+           of the database in use along with an XID.  During a
+           :meth:`~.commit`, after :meth:`~.flush` has been issued for all
+           attached databases, the :meth:`~.TwoPhaseTransaction.prepare`
+           method on each database's :class:`.TwoPhaseTransaction` will be
+           called.  This allows each database to roll back the entire
+           transaction before each transaction is committed.
+
+        :param weak_identity_map:  Defaults to ``True`` - when set to
+           ``False``, objects placed in the :class:`.Session` will be
+           strongly referenced until explicitly removed or the
+           :class:`.Session` is closed.  **Deprecated** - The strong
+           reference identity map is legacy.  See the
+           recipe at :ref:`session_referencing_behavior` for
+           an event-based approach to maintaining strong identity
+           references.

        """
-
+
        if weak_identity_map:
            self._identity_cls = identity.WeakInstanceDict
        else:
+            util.warn_deprecated(
+                "weak_identity_map=False is deprecated. 
" + "See the documentation on 'Session Referencing Behavior' " + "for an event-based approach to maintaining strong identity " + "references.") + self._identity_cls = identity.StrongInstanceDict self.identity_map = self._identity_cls() @@ -542,49 +731,65 @@ class Session(object): self.bind = bind self.__binds = {} self._flushing = False + self._warn_on_events = False self.transaction = None - self.hash_key = id(self) + self.hash_key = _new_sessionid() self.autoflush = autoflush self.autocommit = autocommit self.expire_on_commit = expire_on_commit self._enable_transaction_accounting = _enable_transaction_accounting self.twophase = twophase - self.extensions = util.to_list(extension) or [] self._query_cls = query_cls - self._mapper_flush_opts = {} + if info: + self.info.update(info) + + if extension: + for ext in util.to_list(extension): + SessionExtension._adapt_listener(self, ext) if binds is not None: - for mapperortable, bind in binds.iteritems(): - if isinstance(mapperortable, (type, Mapper)): - self.bind_mapper(mapperortable, bind) - else: - self.bind_table(mapperortable, bind) + for key, bind in binds.items(): + self._add_bind(key, bind) if not self.autocommit: self.begin() _sessions[self.hash_key] = self + connection_callable = None + + transaction = None + """The current active or inactive :class:`.SessionTransaction`.""" + + @util.memoized_property + def info(self): + """A user-modifiable dictionary. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + .. versionadded:: 0.9.0 + + """ + return {} + def begin(self, subtransactions=False, nested=False): - """Begin a transaction on this Session. + """Begin a transaction on this :class:`.Session`. If this Session is already within a transaction, either a plain transaction or nested transaction, an error is raised, unless ``subtransactions=True`` or ``nested=True`` is specified. - The ``subtransactions=True`` flag indicates that this ``begin()`` can - create a subtransaction if a transaction is already in progress. A - subtransaction is a non-transactional, delimiting construct that - allows matching begin()/commit() pairs to be nested together, with - only the outermost begin/commit pair actually affecting transactional - state. When a rollback is issued, the subtransaction will directly - roll back the innermost real transaction, however each subtransaction - still must be explicitly rolled back to maintain proper stacking of - subtransactions. - - If no transaction is in progress, then a real transaction is begun. + The ``subtransactions=True`` flag indicates that this + :meth:`~.Session.begin` can create a subtransaction if a transaction + is already in progress. For documentation on subtransactions, please + see :ref:`session_subtransactions`. The ``nested`` flag begins a SAVEPOINT transaction and is equivalent - to calling ``begin_nested()``. + to calling :meth:`~.Session.begin_nested`. For documentation on + SAVEPOINT transactions, please see :ref:`session_begin_nested`. """ if self.transaction is not None: @@ -593,8 +798,8 @@ class Session(object): nested=nested) else: raise sa_exc.InvalidRequestError( - "A transaction is already begun. Use subtransactions=True " - "to allow subtransactions.") + "A transaction is already begun. 
Use " + "subtransactions=True to allow subtransactions.") else: self.transaction = SessionTransaction( self, nested=nested) @@ -606,10 +811,8 @@ class Session(object): The target database(s) must support SQL SAVEPOINTs or a SQLAlchemy-supported vendor implementation of the idea. - The nested transaction is a real transation, unlike a "subtransaction" - which corresponds to multiple ``begin()`` calls. The next - ``rollback()`` or ``commit()`` call will operate upon this nested - transaction. + For documentation on SAVEPOINT + transactions, please see :ref:`session_begin_nested`. """ return self.begin(nested=True) @@ -622,7 +825,11 @@ class Session(object): This method rolls back the current transaction or nested transaction regardless of subtransactions being in effect. All subtransactions up to the first real transaction are closed. Subtransactions occur when - begin() is called multiple times. + :meth:`.begin` is called multiple times. + + .. seealso:: + + :ref:`session_rollback` """ if self.transaction is None: @@ -634,17 +841,29 @@ class Session(object): """Flush pending changes and commit the current transaction. If no transaction is in progress, this method raises an - InvalidRequestError. + :exc:`~sqlalchemy.exc.InvalidRequestError`. + + By default, the :class:`.Session` also expires all database + loaded state on all ORM-managed attributes after transaction commit. + This so that subsequent operations load the most recent + data from the database. This behavior can be disabled using + the ``expire_on_commit=False`` option to :class:`.sessionmaker` or + the :class:`.Session` constructor. If a subtransaction is in effect (which occurs when begin() is called multiple times), the subtransaction will be closed, and the next call to ``commit()`` will operate on the enclosing transaction. - For a session configured with autocommit=False, a new transaction will + When using the :class:`.Session` in its default mode of + ``autocommit=False``, a new transaction will be begun immediately after the commit, but note that the newly begun transaction does *not* use any connection resources until the first SQL is actually emitted. + .. seealso:: + + :ref:`session_committing` + """ if self.transaction is None: if not self.autocommit: @@ -658,10 +877,11 @@ class Session(object): """Prepare the current transaction in progress for two phase commit. If no transaction is in progress, this method raises an - InvalidRequestError. + :exc:`~sqlalchemy.exc.InvalidRequestError`. Only root transactions of two phase sessions can be prepared. If the - current transaction is not such, an InvalidRequestError is raised. + current transaction is not such, an + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. """ if self.transaction is None: @@ -672,74 +892,225 @@ class Session(object): self.transaction.prepare() - def connection(self, mapper=None, clause=None): - """Return the active Connection. + def connection(self, mapper=None, clause=None, + bind=None, + close_with_result=False, + execution_options=None, + **kw): + r"""Return a :class:`.Connection` object corresponding to this + :class:`.Session` object's transactional state. - Retrieves the ``Connection`` managing the current transaction. Any - operations executed on the Connection will take place in the same - transactional context as ``Session`` operations. 
+        If this :class:`.Session` is configured with ``autocommit=False``,
+        either the :class:`.Connection` corresponding to the current
+        transaction is returned, or if no transaction is in progress, a new
+        one is begun and the :class:`.Connection` returned (note that no
+        transactional state is established with the DBAPI until the first
+        SQL statement is emitted).

-        For ``autocommit`` Sessions with no active manual transaction,
-        ``connection()`` is a passthrough to ``contextual_connect()`` on the
-        underlying engine.
+        Alternatively, if this :class:`.Session` is configured with
+        ``autocommit=True``, an ad-hoc :class:`.Connection` is returned
+        using :meth:`.Engine.contextual_connect` on the underlying
+        :class:`.Engine`.

-        Ambiguity in multi-bind or unbound Sessions can be resolved through
-        any of the optional keyword arguments.  See ``get_bind()`` for more
-        information.
+        Ambiguity in multi-bind or unbound :class:`.Session` objects can be
+        resolved through any of the optional keyword arguments.  This
+        ultimately makes use of the :meth:`.get_bind` method for resolution.

-        mapper
-          Optional, a ``mapper`` or mapped class
+        :param bind:
+          Optional :class:`.Engine` to be used as the bind.  If
+          this engine is already involved in an ongoing transaction,
+          that connection will be used.  This argument takes precedence
+          over ``mapper``, ``clause``.

-        clause
-          Optional, any ``ClauseElement``
+        :param mapper:
+          Optional :func:`.mapper` mapped class, used to identify
+          the appropriate bind.  This argument takes precedence over
+          ``clause``.
+
+        :param clause:
+            A :class:`.ClauseElement` (i.e. :func:`~.sql.expression.select`,
+            :func:`~.sql.expression.text`,
+            etc.) which will be used to locate a bind, if a bind
+            cannot otherwise be identified.
+
+        :param close_with_result: Passed to :meth:`.Engine.connect`,
+          indicating the :class:`.Connection` should be considered
+          "single use", automatically closing when the first result set is
+          closed.  This flag only has an effect if this :class:`.Session` is
+          configured with ``autocommit=True`` and does not already have a
+          transaction in progress.
+
+        :param execution_options: a dictionary of execution options that will
+         be passed to :meth:`.Connection.execution_options`, **when the
+         connection is first procured only**.   If the connection is already
+         present within the :class:`.Session`, a warning is emitted and
+         the arguments are ignored.
+
+         .. versionadded:: 0.9.9
+
+         .. seealso::
+
+            :ref:`session_transaction_isolation`
+
+        :param \**kw:
+          Additional keyword arguments are sent to :meth:`get_bind()`,
+          allowing additional arguments to be passed to custom
+          implementations of :meth:`get_bind`.

        """
-        return self._connection_for_bind(self.get_bind(mapper, clause))
+        if bind is None:
+            bind = self.get_bind(mapper, clause=clause, **kw)

-    def _connection_for_bind(self, engine, **kwargs):
+        return self._connection_for_bind(bind,
+                                         close_with_result=close_with_result,
+                                         execution_options=execution_options)
+
+    def _connection_for_bind(self, engine, execution_options=None, **kw):
        if self.transaction is not None:
-            return self.transaction._connection_for_bind(engine)
+            return self.transaction._connection_for_bind(
+                engine, execution_options)
        else:
-            return engine.contextual_connect(**kwargs)
+            conn = engine.contextual_connect(**kw)
+            if execution_options:
+                conn = conn.execution_options(**execution_options)
+            return conn

-    def execute(self, clause, params=None, mapper=None, **kw):
-        """Execute a clause within the current transaction.
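The ``execution_options`` pathway in ``_connection_for_bind()`` above is
what permits per-transaction options such as an isolation level to be
applied when the connection is first procured.  A hedged sketch
(``engine`` is a hypothetical :class:`.Engine`; support for
``isolation_level`` varies by dialect)::

    session = Session(bind=engine)
    conn = session.connection(
        execution_options={'isolation_level': 'SERIALIZABLE'})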
+    def execute(self, clause, params=None, mapper=None, bind=None, **kw):
+        r"""Execute a SQL expression construct or string statement within
+        the current transaction.

-        Returns a ``ResultProxy`` of execution results. `autocommit` Sessions
-        will create a transaction on the fly.
+        Returns a :class:`.ResultProxy` representing
+        results of the statement execution, in the same manner as that of an
+        :class:`.Engine` or
+        :class:`.Connection`.

-        Connection ambiguity in multi-bind or unbound Sessions will be
-        resolved by inspecting the clause for binds. The 'mapper' and
-        'instance' keyword arguments may be used if this is insufficient, See
-        ``get_bind()`` for more information.
+        E.g.::

-        clause
-            A ClauseElement (i.e. select(), text(), etc.) or
-            string SQL statement to be executed
+            result = session.execute(
+                        user_table.select().where(user_table.c.id == 5)
+                    )

-        params
-            Optional, a dictionary of bind parameters.
+        :meth:`~.Session.execute` accepts any executable clause construct,
+        such as :func:`~.sql.expression.select`,
+        :func:`~.sql.expression.insert`,
+        :func:`~.sql.expression.update`,
+        :func:`~.sql.expression.delete`, and
+        :func:`~.sql.expression.text`.  Plain SQL strings can be passed
+        as well; in the case of :meth:`.Session.execute` only, they
+        will be interpreted the same as if passed via a
+        :func:`~.expression.text` construct.  That is, the following usage::

-        mapper
-            Optional, a ``mapper`` or mapped class
+            result = session.execute(
+                        "SELECT * FROM user WHERE id=:param",
+                        {"param":5}
+                    )
+
+        is equivalent to::
+
+            from sqlalchemy import text
+            result = session.execute(
+                        text("SELECT * FROM user WHERE id=:param"),
+                        {"param":5}
+                    )
+
+        The second positional argument to :meth:`.Session.execute` is an
+        optional parameter set.  Similar to that of
+        :meth:`.Connection.execute`, whether this is passed as a single
+        dictionary, or a list of dictionaries, determines whether the DBAPI
+        cursor's ``execute()`` or ``executemany()`` is used to execute the
+        statement.   An INSERT construct may be invoked for a single row::
+
+            result = session.execute(
+                users.insert(), {"id": 7, "name": "somename"})
+
+        or for multiple rows::
+
+            result = session.execute(users.insert(), [
+                {"id": 7, "name": "somename7"},
+                {"id": 8, "name": "somename8"},
+                {"id": 9, "name": "somename9"}
+            ])
+
+        The statement is executed within the current transactional context of
+        this :class:`.Session`.   The :class:`.Connection` which is used
+        to execute the statement can also be acquired directly by
+        calling the :meth:`.Session.connection` method.  Both methods use
+        a rule-based resolution scheme in order to determine the
+        :class:`.Connection`, which in the average case is derived directly
+        from the "bind" of the :class:`.Session` itself, and in other cases
+        can be based on the :func:`.mapper`
+        and :class:`.Table` objects passed to the method; see the
+        documentation for :meth:`.Session.get_bind` for a full description of
+        this scheme.
+
+        The :meth:`.Session.execute` method does *not* invoke autoflush.
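Because :meth:`.Session.execute` bypasses autoflush, pending ORM changes
are not visible to a statement unless flushed explicitly first.  A minimal
sketch, assuming a hypothetical ``User`` mapping against the ``user`` table
used in the examples above::

    from sqlalchemy import text

    session.add(User(name='ed'))
    count = session.execute(
        text("SELECT count(*) FROM user")).scalar()  # pending row not counted
    session.flush()                                  # emits the INSERT
    count = session.execute(
        text("SELECT count(*) FROM user")).scalar()  # now included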
+
+        The :class:`.ResultProxy` returned by the :meth:`.Session.execute`
+        method is returned with the "close_with_result" flag set to true;
+        the significance of this flag is that if this :class:`.Session` is
+        autocommitting and does not have a transaction-dedicated
+        :class:`.Connection` available, a temporary :class:`.Connection` is
+        established for the statement execution, which is closed (meaning,
+        returned to the connection pool) when the :class:`.ResultProxy` has
+        consumed all available data.  This applies *only* when the
+        :class:`.Session` is configured with autocommit=True and no
+        transaction has been started.
+
+        :param clause:
+            An executable statement (i.e. an :class:`.Executable` expression
+            such as :func:`.expression.select`) or string SQL statement
+            to be executed.
+
+        :param params:
+            Optional dictionary, or list of dictionaries, containing
+            bound parameter values.   If a single dictionary, single-row
+            execution occurs; if a list of dictionaries, an
+            "executemany" will be invoked.  The keys in each dictionary
+            must correspond to parameter names present in the statement.
+
+        :param mapper:
+          Optional :func:`.mapper` or mapped class, used to identify
+          the appropriate bind.  This argument takes precedence over
+          ``clause`` when locating a bind.  See :meth:`.Session.get_bind`
+          for more details.
+
+        :param bind:
+          Optional :class:`.Engine` to be used as the bind.  If
+          this engine is already involved in an ongoing transaction,
+          that connection will be used.  This argument takes
+          precedence over ``mapper`` and ``clause`` when locating
+          a bind.
+
+        :param \**kw:
+          Additional keyword arguments are sent to :meth:`.Session.get_bind`
+          to allow extensibility of "bind" schemes.
+
+        .. seealso::
+
+            :ref:`sqlexpression_toplevel` - Tutorial on using Core SQL
+            constructs.
+
+            :ref:`connections_toplevel` - Further information on direct
+            statement execution.
+
+            :meth:`.Connection.execute` - core level statement execution
+            method, which :meth:`.Session.execute` ultimately uses
+            in order to execute the statement.

-        \**kw
-          Additional keyword arguments are sent to :meth:`get_bind()`
-          which locates a connectable to use for the execution.
-          Subclasses of :class:`Session` may override this.
-
        """
        clause = expression._literal_as_text(clause)
-        engine = self.get_bind(mapper, clause=clause, **kw)
+        if bind is None:
+            bind = self.get_bind(mapper, clause=clause, **kw)

-        return self._connection_for_bind(engine, close_with_result=True).execute(
-            clause, params or {})
+        return self._connection_for_bind(
+            bind, close_with_result=True).execute(clause, params or {})

-    def scalar(self, clause, params=None, mapper=None, **kw):
-        """Like execute() but return a scalar result."""
-
-        return self.execute(clause, params=params, mapper=mapper, **kw).scalar()
+    def scalar(self, clause, params=None, mapper=None, bind=None, **kw):
+        """Like :meth:`~.Session.execute` but return a scalar result."""
+
+        return self.execute(
+            clause, params=params, mapper=mapper, bind=bind, **kw).scalar()

    def close(self):
        """Close this Session.
@@ -751,17 +1122,46 @@ class Session(object):
        not use any connection resources until they are first needed.

        """
+        self._close_impl(invalidate=False)
+
+    def invalidate(self):
+        """Close this Session, using connection invalidation.
+
+        This is a variant of :meth:`.Session.close` that will additionally
+        ensure that the :meth:`.Connection.invalidate` method will be called
+        on all :class:`.Connection` objects.
This can be called when
+        the database is known to be in a state where the connections are
+        no longer safe to be used.
+
+        E.g.::
+
+            try:
+                sess = Session()
+                sess.add(User())
+                sess.commit()
+            except gevent.Timeout:
+                sess.invalidate()
+                raise
+            except:
+                sess.rollback()
+                raise
+
+        This clears all items and ends any transaction in progress.
+
+        If this session were created with ``autocommit=False``, a new
+        transaction is immediately begun.  Note that this new transaction does
+        not use any connection resources until they are first needed.
+
+        .. versionadded:: 0.9.9
+
+        """
+        self._close_impl(invalidate=True)
+
+    def _close_impl(self, invalidate):
        self.expunge_all()
        if self.transaction is not None:
-            for transaction in self.transaction._iterate_parents():
-                transaction.close()
-
-    @classmethod
-    def close_all(cls):
-        """Close *all* sessions in memory."""
-
-        for sess in _sessions.values():
-            sess.close()
+            for transaction in self.transaction._iterate_self_and_parents():
+                transaction.close(invalidate)

    def expunge_all(self):
        """Remove all object instances from this ``Session``.
@@ -770,63 +1170,109 @@ class Session(object):
        ``Session``.

        """
-        for state in self.identity_map.all_states() + list(self._new):
-            state.detach()
+        all_states = self.identity_map.all_states() + list(self._new)
        self.identity_map = self._identity_cls()
        self._new = {}
        self._deleted = {}

-    # TODO: need much more test coverage for bind_mapper() and similar !
-    # TODO: + crystalize + document resolution order vis. bind_mapper/bind_table
+        statelib.InstanceState._detach_states(
+            all_states, self
+        )
+
+    def _add_bind(self, key, bind):
+        try:
+            insp = inspect(key)
+        except sa_exc.NoInspectionAvailable:
+            if not isinstance(key, type):
+                raise sa_exc.ArgumentError(
+                    "Not an acceptable bind target: %s" % key)
+            else:
+                self.__binds[key] = bind
+        else:
+            if insp.is_selectable:
+                self.__binds[insp] = bind
+            elif insp.is_mapper:
+                self.__binds[insp.class_] = bind
+                for selectable in insp._all_tables:
+                    self.__binds[selectable] = bind
+            else:
+                raise sa_exc.ArgumentError(
+                    "Not an acceptable bind target: %s" % key)

    def bind_mapper(self, mapper, bind):
-        """Bind operations for a mapper to a Connectable.
+        """Associate a :class:`.Mapper` with a "bind", e.g. an
+        :class:`.Engine` or :class:`.Connection`.

-        mapper
-          A mapper instance or mapped class
-
-        bind
-          Any Connectable: a ``Engine`` or ``Connection``.
-
-        All subsequent operations involving this mapper will use the given
-        `bind`.
+        The given mapper is added to a lookup used by the
+        :meth:`.Session.get_bind` method.

        """
-        if isinstance(mapper, type):
-            mapper = _class_mapper(mapper)
-
-        self.__binds[mapper.base_mapper] = bind
-        for t in mapper._all_tables:
-            self.__binds[t] = bind
+        self._add_bind(mapper, bind)

    def bind_table(self, table, bind):
-        """Bind operations on a Table to a Connectable.
+        """Associate a :class:`.Table` with a "bind", e.g. an
+        :class:`.Engine` or :class:`.Connection`.

-        table
-          A ``Table`` instance
-
-        bind
-          Any Connectable: a ``Engine`` or ``Connection``.
-
-        All subsequent operations involving this ``Table`` will use the
-        given `bind`.
+        The given table is added to a lookup used by the
+        :meth:`.Session.get_bind` method.

        """
-        self.__binds[table] = bind
+        self._add_bind(table, bind)

-    def get_bind(self, mapper, clause=None):
-        """Return an engine corresponding to the given arguments.
+    def get_bind(self, mapper=None, clause=None):
+        """Return a "bind" to which this :class:`.Session` is bound.

-        All arguments are optional.
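Complementing the ``binds`` constructor argument described earlier, the
``bind_mapper()`` / ``bind_table()`` methods above can establish per-entity
binds after construction.  A minimal sketch, with hypothetical engines,
mapping ``User``, and table ``log_table``::

    session = Session()
    session.bind_mapper(User, engine_one)      # ORM operations against User
    session.bind_table(log_table, engine_two)  # Core operations on log_table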
+ The "bind" is usually an instance of :class:`.Engine`, + except in the case where the :class:`.Session` has been + explicitly bound directly to a :class:`.Connection`. - mapper - Optional, a ``Mapper`` or mapped class + For a multiply-bound or unbound :class:`.Session`, the + ``mapper`` or ``clause`` arguments are used to determine the + appropriate bind to return. - clause - Optional, A ClauseElement (i.e. select(), text(), etc.) + Note that the "mapper" argument is usually present + when :meth:`.Session.get_bind` is called via an ORM + operation such as a :meth:`.Session.query`, each + individual INSERT/UPDATE/DELETE operation within a + :meth:`.Session.flush`, call, etc. + + The order of resolution is: + + 1. if mapper given and session.binds is present, + locate a bind based on mapper. + 2. if clause given and session.binds is present, + locate a bind based on :class:`.Table` objects + found in the given clause present in session.binds. + 3. if session.bind is present, return that. + 4. if clause given, attempt to return a bind + linked to the :class:`.MetaData` ultimately + associated with the clause. + 5. if mapper given, attempt to return a bind + linked to the :class:`.MetaData` ultimately + associated with the :class:`.Table` or other + selectable to which the mapper is mapped. + 6. No bind can be found, :exc:`~sqlalchemy.exc.UnboundExecutionError` + is raised. + + :param mapper: + Optional :func:`.mapper` mapped class or instance of + :class:`.Mapper`. The bind can be derived from a :class:`.Mapper` + first by consulting the "binds" map associated with this + :class:`.Session`, and secondly by consulting the :class:`.MetaData` + associated with the :class:`.Table` to which the :class:`.Mapper` + is mapped for a bind. + + :param clause: + A :class:`.ClauseElement` (i.e. :func:`~.sql.expression.select`, + :func:`~.sql.expression.text`, + etc.). If the ``mapper`` argument is not present or could not + produce a bind, the given expression construct will be searched + for a bound element, typically a :class:`.Table` associated with + bound :class:`.MetaData`. """ + if mapper is clause is None: if self.bind: return self.bind @@ -836,15 +1282,23 @@ class Session(object): "Connection, and no context was provided to locate " "a binding.") - c_mapper = mapper is not None and _class_to_mapper(mapper) or None - - # manually bound? 
+ if mapper is not None: + try: + mapper = inspect(mapper) + except sa_exc.NoInspectionAvailable: + if isinstance(mapper, type): + raise exc.UnmappedClassError(mapper) + else: + raise + if self.__binds: - if c_mapper: - if c_mapper.base_mapper in self.__binds: - return self.__binds[c_mapper.base_mapper] - elif c_mapper.mapped_table in self.__binds: - return self.__binds[c_mapper.mapped_table] + if mapper: + for cls in mapper.class_.__mro__: + if cls in self.__binds: + return self.__binds[cls] + if clause is None: + clause = mapper.mapped_table + if clause is not None: for t in sql_util.find_tables(clause, include_crud=True): if t in self.__binds: @@ -856,34 +1310,72 @@ class Session(object): if isinstance(clause, sql.expression.ClauseElement) and clause.bind: return clause.bind - if c_mapper and c_mapper.mapped_table.bind: - return c_mapper.mapped_table.bind + if mapper and mapper.mapped_table.bind: + return mapper.mapped_table.bind context = [] if mapper is not None: - context.append('mapper %s' % c_mapper) + context.append('mapper %s' % mapper) if clause is not None: context.append('SQL expression') - + raise sa_exc.UnboundExecutionError( "Could not locate a bind configured on %s or this Session" % ( - ', '.join(context))) + ', '.join(context))) def query(self, *entities, **kwargs): - """Return a new ``Query`` object corresponding to this ``Session``.""" + """Return a new :class:`.Query` object corresponding to this + :class:`.Session`.""" return self._query_cls(entities, self, **kwargs) + @property + @util.contextmanager + def no_autoflush(self): + """Return a context manager that disables autoflush. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + .. versionadded:: 0.7.6 + + """ + autoflush = self.autoflush + self.autoflush = False + try: + yield self + finally: + self.autoflush = autoflush + def _autoflush(self): if self.autoflush and not self._flushing: - self.flush() - - def _finalize_loaded(self, states): - for state, dict_ in states.items(): - state.commit_all(dict_, self.identity_map) + try: + self.flush() + except sa_exc.StatementError as e: + # note we are reraising StatementError as opposed to + # raising FlushError with "chaining" to remain compatible + # with code that catches StatementError, IntegrityError, + # etc. + e.add_detail( + "raised as a result of Query-invoked autoflush; " + "consider using a session.no_autoflush block if this " + "flush is occurring prematurely") + util.raise_from_cause(e) def refresh(self, instance, attribute_names=None, lockmode=None): - """Refresh the attributes on the given instance. + """Expire and refresh the attributes on the given instance. A query will be issued to the database and all attributes will be refreshed with their current database value. @@ -895,62 +1387,143 @@ class Session(object): Eagerly-loaded relational attributes will eagerly load within the single refresh operation. 
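As a quick illustration of the refresh operation just described (``user``
here is a hypothetical persistent instance)::

    session.refresh(user)            # reload all attributes from the database
    session.refresh(user, ['name'])  # reload only the 'name' attribute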
+        Note that a highly isolated transaction will return the same values as
+        were previously read in that same transaction, regardless of changes
+        in database state outside of that transaction - usage of
+        :meth:`~Session.refresh` usually only makes sense if non-ORM SQL
+        statements were emitted in the ongoing transaction, or if autocommit
+        mode is turned on.
+
        :param attribute_names: optional.  An iterable collection of
-          string attribute names indicating a subset of attributes to
+          string attribute names indicating a subset of attributes to be
          refreshed.
-
-        :param lockmode: Passed to the :class:`~sqlalchemy.orm.query.Query`
+
+        :param lockmode: Passed to the :class:`~sqlalchemy.orm.query.Query`
          as used by :meth:`~sqlalchemy.orm.query.Query.with_lockmode`.
-
+
+        .. seealso::
+
+            :ref:`session_expire` - introductory material
+
+            :meth:`.Session.expire`
+
+            :meth:`.Session.expire_all`
+
        """
        try:
            state = attributes.instance_state(instance)
        except exc.NO_STATE:
            raise exc.UnmappedInstanceError(instance)
-        self._validate_persistent(state)
-        if self.query(_object_mapper(instance))._get(
+
+        self._expire_state(state, attribute_names)
+
+        if loading.load_on_ident(
+                self.query(object_mapper(instance)),
                state.key, refresh_state=state,
                lockmode=lockmode,
                only_load_props=attribute_names) is None:
            raise sa_exc.InvalidRequestError(
                "Could not refresh instance '%s'" %
-                mapperutil.instance_str(instance))
+                instance_str(instance))

    def expire_all(self):
-        """Expires all persistent instances within this Session."""
+        """Expires all persistent instances within this Session.

+        When any attribute on a persistent instance is next accessed,
+        a query will be issued using the
+        :class:`.Session` object's current transactional context in order to
+        load all expired attributes for the given instance.  Note that
+        a highly isolated transaction will return the same values as were
+        previously read in that same transaction, regardless of changes
+        in database state outside of that transaction.
+
+        To expire individual objects and individual attributes
+        on those objects, use :meth:`Session.expire`.
+
+        The :class:`.Session` object's default behavior is to
+        expire all state whenever the :meth:`Session.rollback`
+        or :meth:`Session.commit` methods are called, so that new
+        state can be loaded for the new transaction.   For this reason,
+        calling :meth:`Session.expire_all` should not be needed when
+        autocommit is ``False``, assuming the transaction is isolated.
+
+        .. seealso::
+
+            :ref:`session_expire` - introductory material
+
+            :meth:`.Session.expire`
+
+            :meth:`.Session.refresh`
+
+        """
        for state in self.identity_map.all_states():
-            _expire_state(state, state.dict, None, instance_dict=self.identity_map)
+            state._expire(state.dict, self.identity_map._modified)

    def expire(self, instance, attribute_names=None):
        """Expire the attributes on an instance.

-        Marks the attributes of an instance as out of date. When an expired
-        attribute is next accessed, query will be issued to the database and
-        the attributes will be refreshed with their current database value.
-        ``expire()`` is a lazy variant of ``refresh()``.
+        Marks the attributes of an instance as out of date. When an expired
+        attribute is next accessed, a query will be issued to the
+        :class:`.Session` object's current transactional context in order to
+        load all expired attributes for the given instance.
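A short sketch of the lazy nature of expiration (``user`` is a hypothetical
persistent instance; no SQL is emitted until the attribute is touched)::

    session.expire(user, ['name'])  # no SQL emitted here
    print(user.name)                # attribute access emits a SELECT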
Note that
+        a highly isolated transaction will return the same values as were
+        previously read in that same transaction, regardless of changes
+        in database state outside of that transaction.

-        The ``attribute_names`` argument is an iterable collection
-        of attribute names indicating a subset of attributes to be
-        expired.
+        To expire all objects in the :class:`.Session` simultaneously,
+        use :meth:`Session.expire_all`.
+
+        The :class:`.Session` object's default behavior is to
+        expire all state whenever the :meth:`Session.rollback`
+        or :meth:`Session.commit` methods are called, so that new
+        state can be loaded for the new transaction.   For this reason,
+        calling :meth:`Session.expire` only makes sense for the specific
+        case that a non-ORM SQL statement was emitted in the current
+        transaction.
+
+        :param instance: The instance to be expired.
+        :param attribute_names: optional list of string attribute names
+          indicating a subset of attributes to be expired.
+
+        .. seealso::
+
+            :ref:`session_expire` - introductory material
+
+            :meth:`.Session.expire_all`
+
+            :meth:`.Session.refresh`

        """
        try:
            state = attributes.instance_state(instance)
        except exc.NO_STATE:
            raise exc.UnmappedInstanceError(instance)
+        self._expire_state(state, attribute_names)
+
+    def _expire_state(self, state, attribute_names):
        self._validate_persistent(state)
        if attribute_names:
-            _expire_state(state, state.dict,
-                          attribute_names=attribute_names, instance_dict=self.identity_map)
+            state._expire_attributes(state.dict, attribute_names)
        else:
            # pre-fetch the full cascade since the expire is going to
            # remove associations
-            cascaded = list(_cascade_state_iterator('refresh-expire', state))
-            _expire_state(state, state.dict, None, instance_dict=self.identity_map)
-            for (state, m, o) in cascaded:
-                _expire_state(state, state.dict, None, instance_dict=self.identity_map)
+            cascaded = list(state.manager.mapper.cascade_iterator(
+                'refresh-expire', state))
+            self._conditional_expire(state)
+            for o, m, st_, dct_ in cascaded:
+                self._conditional_expire(st_)

+    def _conditional_expire(self, state):
+        """Expire a state if persistent, else expunge if pending"""
+
+        if state.key:
+            state._expire(state.dict, self.identity_map._modified)
+        elif state in self._new:
+            self._new.pop(state)
+            state._detach(self)
+
+    @util.deprecated("0.7", "The non-weak-referencing identity map "
+                     "feature is no longer needed.")
    def prune(self):
        """Remove unreferenced instances cached in the identity map.
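Per ``_conditional_expire()`` above, when an expire operation cascades onto
a related object that is still pending, that object is expunged rather than
expired; only persistent states are actually marked expired.  A hedged
sketch, assuming a hypothetical ``Parent.children`` relationship whose
cascade includes ``refresh-expire``::

    parent = session.query(Parent).first()
    parent.children.append(Child())   # a pending object
    session.expire(parent)            # parent's attributes are expired;
                                      # the pending Child is expunged instead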
@@ -978,60 +1551,114 @@ class Session(object):
        if state.session_id is not self.hash_key:
            raise sa_exc.InvalidRequestError(
                "Instance %s is not present in this Session" %
-                mapperutil.state_str(state))
-        for s, m, o in [(state, None, None)] + list(_cascade_state_iterator('expunge', state)):
-            self._expunge_state(s)
+                state_str(state))

-    def _expunge_state(self, state):
-        if state in self._new:
-            self._new.pop(state)
-            state.detach()
-        elif self.identity_map.contains_state(state):
-            self.identity_map.discard(state)
-            self._deleted.pop(state, None)
-            state.detach()
+        cascaded = list(state.manager.mapper.cascade_iterator(
+            'expunge', state))
+        self._expunge_states(
+            [state] + [st_ for o, m, st_, dct_ in cascaded]
+        )

-    def _register_newly_persistent(self, state):
-        mapper = _state_mapper(state)
+    def _expunge_states(self, states, to_transient=False):
+        for state in states:
+            if state in self._new:
+                self._new.pop(state)
+            elif self.identity_map.contains_state(state):
+                self.identity_map.safe_discard(state)
+                self._deleted.pop(state, None)
+            elif self.transaction:
+                # state is "detached" from being deleted, but still present
+                # in the transaction snapshot
+                self.transaction._deleted.pop(state, None)
+        statelib.InstanceState._detach_states(
+            states, self, to_transient=to_transient)

-        # prevent against last minute dereferences of the object
-        obj = state.obj()
-        if obj is not None:
+    def _register_newly_persistent(self, states):
+        pending_to_persistent = self.dispatch.pending_to_persistent or None
+        for state in states:
+            mapper = _state_mapper(state)

-            instance_key = mapper._identity_key_from_state(state)
+            # prevent against last minute dereferences of the object
+            obj = state.obj()
+            if obj is not None:
+
+                instance_key = mapper._identity_key_from_state(state)
+
+                if _none_set.intersection(instance_key[1]) and \
+                        not mapper.allow_partial_pks or \
+                        _none_set.issuperset(instance_key[1]):
+                    raise exc.FlushError(
+                        "Instance %s has a NULL identity key.  If this is an "
+                        "auto-generated value, check that the database table "
+                        "allows generation of new primary key values, and "
+                        "that the mapped Column object is configured to "
+                        "expect these generated values.  Ensure also that "
+                        "this flush() is not occurring at an inappropriate "
+                        "time, such as within a load() event."
+                        % state_str(state)
+                    )
+
+                if state.key is None:
+                    state.key = instance_key
+                elif state.key != instance_key:
+                    # primary key switch. use safe_discard() in case another
+                    # state has already replaced this one in the identity
+                    # map (see test/orm/test_naturalpks.py ReversePKsTest)
+                    self.identity_map.safe_discard(state)
+                    if state in self.transaction._key_switches:
+                        orig_key = self.transaction._key_switches[state][0]
+                    else:
+                        orig_key = state.key
+                    self.transaction._key_switches[state] = (
+                        orig_key, instance_key)
+                    state.key = instance_key
+
+                self.identity_map.replace(state)
+
+        statelib.InstanceState._commit_all_states(
+            ((state, state.dict) for state in states),
+            self.identity_map
+        )
+
+        self._register_altered(states)
+
+        if pending_to_persistent is not None:
+            for state in states:
+                pending_to_persistent(self, state.obj())

-        if state.key is None:
-            state.key = instance_key
-        elif state.key != instance_key:
-            # primary key switch.
- # use discard() in case another state has already replaced this - # one in the identity map (see test/orm/test_naturalpks.py ReversePKsTest) - self.identity_map.discard(state) - state.key = instance_key - - self.identity_map.replace(state) - state.commit_all(state.dict, self.identity_map) - # remove from new last, might be the last strong ref - if state in self._new: - if self._enable_transaction_accounting and self.transaction: - self.transaction._new[state] = True + for state in set(states).intersection(self._new): self._new.pop(state) - def _remove_newly_deleted(self, state): + def _register_altered(self, states): if self._enable_transaction_accounting and self.transaction: - self.transaction._deleted[state] = True + for state in states: + if state in self._new: + self.transaction._new[state] = True + else: + self.transaction._dirty[state] = True - self.identity_map.discard(state) - self._deleted.pop(state, None) + def _remove_newly_deleted(self, states): + persistent_to_deleted = self.dispatch.persistent_to_deleted or None + for state in states: + if self._enable_transaction_accounting and self.transaction: + self.transaction._deleted[state] = True - def _save_without_cascade(self, instance): - """Used by scoping.py to save on init without cascade.""" + if persistent_to_deleted is not None: + # get a strong reference before we pop out of + # self._deleted + obj = state.obj() - state = _state_for_unsaved_instance(instance, create=True) - self._save_impl(state) + self.identity_map.safe_discard(state) + self._deleted.pop(state, None) + state._deleted = True + # can't call state._detach() here, because this state + # is still in the transaction snapshot and needs to be + # tracked as part of that + if persistent_to_deleted is not None: + persistent_to_deleted(self, obj) - def add(self, instance): + def add(self, instance, _warn=True): """Place an object in the ``Session``. Its state will be persisted to the database on the next flush @@ -1041,23 +1668,34 @@ class Session(object): is ``expunge()``. """ - state = _state_for_unknown_persistence_instance(instance) + if _warn and self._warn_on_events: + self._flush_warning("Session.add()") + + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + self._save_or_update_state(state) def add_all(self, instances): """Add the given collection of instances to this ``Session``.""" + if self._warn_on_events: + self._flush_warning("Session.add_all()") + for instance in instances: - self.add(instance) + self.add(instance, _warn=False) def _save_or_update_state(self, state): self._save_or_update_impl(state) - self._cascade_save_or_update(state) - def _cascade_save_or_update(self, state): - for state, mapper in _cascade_unknown_state_iterator( - 'save-update', state, halt_on=self.__contains__): - self._save_or_update_impl(state) + mapper = _state_mapper(state) + for o, m, st_, dct_ in mapper.cascade_iterator( + 'save-update', + state, + halt_on=self._contains_state): + self._save_or_update_impl(st_) def delete(self, instance): """Mark an instance as deleted. @@ -1065,76 +1703,130 @@ class Session(object): The database delete operation occurs upon ``flush()``. 
""" + if self._warn_on_events: + self._flush_warning("Session.delete()") + try: state = attributes.instance_state(instance) except exc.NO_STATE: raise exc.UnmappedInstanceError(instance) + self._delete_impl(state, instance, head=True) + + def _delete_impl(self, state, obj, head): + if state.key is None: - raise sa_exc.InvalidRequestError( - "Instance '%s' is not persisted" % - mapperutil.state_str(state)) + if head: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % + state_str(state)) + else: + return + + to_attach = self._before_attach(state, obj) if state in self._deleted: return - - # ensure object is attached to allow the - # cascade operation to load deferred attributes - # and collections - self._attach(state) - # grab the cascades before adding the item to the deleted list - # so that autoflush does not delete the item - cascade_states = list(_cascade_state_iterator('delete', state)) - - self._deleted[state] = state.obj() self.identity_map.add(state) - for state, m, o in cascade_states: - self._delete_impl(state) + if to_attach: + self._after_attach(state, obj) - def merge(self, instance, load=True, **kw): - """Copy the state an instance onto the persistent instance with the same identifier. + if head: + # grab the cascades before adding the item to the deleted list + # so that autoflush does not delete the item + # the strong reference to the instance itself is significant here + cascade_states = list(state.manager.mapper.cascade_iterator( + 'delete', state)) - If there is no persistent instance currently associated with the - session, it will be loaded. Return the persistent instance. If the - given instance is unsaved, save a copy of and return it as a newly - persistent instance. The given instance does not become associated - with the session. + self._deleted[state] = obj + + if head: + for o, m, st_, dct_ in cascade_states: + self._delete_impl(st_, o, False) + + def merge(self, instance, load=True): + """Copy the state of a given instance into a corresponding instance + within this :class:`.Session`. + + :meth:`.Session.merge` examines the primary key attributes of the + source instance, and attempts to reconcile it with an instance of the + same primary key in the session. If not found locally, it attempts + to load the object from the database based on primary key, and if + none can be located, creates a new instance. The state of each + attribute on the source instance is then copied to the target + instance. The resulting target instance is then returned by the + method; the original source instance is left unmodified, and + un-associated with the :class:`.Session` if not already. This operation cascades to associated instances if the association is mapped with ``cascade="merge"``. + See :ref:`unitofwork_merging` for a detailed discussion of merging. + + .. versionchanged:: 1.1 - :meth:`.Session.merge` will now reconcile + pending objects with overlapping primary keys in the same way + as persistent. See :ref:`change_3601` for discussion. + + :param instance: Instance to be merged. + :param load: Boolean, when False, :meth:`.merge` switches into + a "high performance" mode which causes it to forego emitting history + events as well as all database access. This flag is used for + cases such as transferring graphs of objects into a :class:`.Session` + from a second level cache, or to transfer just-loaded objects + into the :class:`.Session` owned by a worker thread or process + without re-querying the database. 
+ + The ``load=False`` use case adds the caveat that the given + object has to be in a "clean" state, that is, has no pending changes + to be flushed - even if the incoming object is detached from any + :class:`.Session`. This is so that when + the merge operation populates local attributes and + cascades to related objects and + collections, the values can be "stamped" onto the + target object as is, without generating any history or attribute + events, and without the need to reconcile the incoming data with + any existing related objects or collections that might not + be loaded. The resulting objects from ``load=False`` are always + produced as "clean", so it is only appropriate that the given objects + should be "clean" as well, else this suggests a mis-use of the + method. + + """ - if 'dont_load' in kw: - load = not kw['dont_load'] - util.warn_deprecated("dont_load=True has been renamed to load=False.") - + + if self._warn_on_events: + self._flush_warning("Session.merge()") + _recursive = {} - + _resolve_conflict_map = {} + if load: # flush current contents if we expect to load data self._autoflush() - - _object_mapper(instance) # verify mapped + + object_mapper(instance) # verify mapped autoflush = self.autoflush try: self.autoflush = False return self._merge( - attributes.instance_state(instance), - attributes.instance_dict(instance), - load=load, _recursive=_recursive) + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map) finally: self.autoflush = autoflush - - def _merge(self, state, state_dict, load=True, _recursive=None): + + def _merge(self, state, state_dict, load=True, _recursive=None, + _resolve_conflict_map=None): mapper = _state_mapper(state) if state in _recursive: return _recursive[state] new_instance = False key = state.key - + if key is None: if not load: raise sa_exc.InvalidRequestError( @@ -1143,10 +1835,15 @@ class Session(object): "all changes on mapped instances before merging with " "load=False.") key = mapper._identity_key_from_state(state) + key_is_persistent = attributes.NEVER_SET not in key[1] + else: + key_is_persistent = True if key in self.identity_map: merged = self.identity_map[key] - + elif key_is_persistent and key in _resolve_conflict_map: + merged = _resolve_conflict_map[key] + elif not load: if state.modified: raise sa_exc.InvalidRequestError( @@ -1158,14 +1855,15 @@ class Session(object): merged_state.key = key self._update_impl(merged_state) new_instance = True - - elif not _none_set.issubset(key[1]) or \ - (mapper.allow_partial_pks and - not _none_set.issuperset(key[1])): + + elif key_is_persistent and ( + not _none_set.intersection(key[1]) or + (mapper.allow_partial_pks and + not _none_set.issuperset(key[1]))): merged = self.query(mapper.class_).get(key[1]) else: merged = None - + if merged is None: merged = mapper.class_manager.new_instance() merged_state = attributes.instance_state(merged) @@ -1175,66 +1873,122 @@ class Session(object): else: merged_state = attributes.instance_state(merged) merged_dict = attributes.instance_dict(merged) - + _recursive[state] = merged + _resolve_conflict_map[key] = merged # check that we didn't just pull the exact same - # state out. + # state out. 
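+        # state out; if so, there is nothing to copy.  when the states
+        # differ, reconcile version ids where applicable, then copy
+        # attribute state from the source onto the merge target below.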
if state is not merged_state: + # version check if applicable + if mapper.version_id_col is not None: + existing_version = mapper._get_state_attr_by_column( + state, + state_dict, + mapper.version_id_col, + passive=attributes.PASSIVE_NO_INITIALIZE) + + merged_version = mapper._get_state_attr_by_column( + merged_state, + merged_dict, + mapper.version_id_col, + passive=attributes.PASSIVE_NO_INITIALIZE) + + if existing_version is not attributes.PASSIVE_NO_RESULT and \ + merged_version is not attributes.PASSIVE_NO_RESULT and \ + existing_version != merged_version: + raise exc.StaleDataError( + "Version id '%s' on merged state %s " + "does not match existing version '%s'. " + "Leave the version attribute unset when " + "merging to update the most recent version." + % ( + existing_version, + state_str(merged_state), + merged_version + )) + merged_state.load_path = state.load_path merged_state.load_options = state.load_options - + + # since we are copying load_options, we need to copy + # the callables_ that would have been generated by those + # load_options. + # assumes that the callables we put in state.callables_ + # are not instance-specific (which they should not be) + merged_state._copy_callables(state) + for prop in mapper.iterate_properties: - prop.merge(self, state, state_dict, merged_state, merged_dict, load, _recursive) + prop.merge(self, state, state_dict, + merged_state, merged_dict, + load, _recursive, _resolve_conflict_map) if not load: # remove any history - merged_state.commit_all(merged_dict, self.identity_map) + merged_state._commit_all(merged_dict, self.identity_map) if new_instance: - merged_state._run_on_load(merged) + merged_state.manager.dispatch.load(merged_state, None) return merged - @classmethod - def identity_key(cls, *args, **kwargs): - return mapperutil.identity_key(*args, **kwargs) - - @classmethod - def object_session(cls, instance): - """Return the ``Session`` to which an object belongs.""" - - return object_session(instance) - def _validate_persistent(self, state): if not self.identity_map.contains_state(state): raise sa_exc.InvalidRequestError( "Instance '%s' is not persistent within this Session" % - mapperutil.state_str(state)) + state_str(state)) def _save_impl(self, state): if state.key is not None: raise sa_exc.InvalidRequestError( - "Object '%s' already has an identity - it can't be registered " - "as pending" % mapperutil.state_str(state)) - - self._attach(state) - if state not in self._new: - self._new[state] = state.obj() - state.insert_order = len(self._new) + "Object '%s' already has an identity - " + "it can't be registered as pending" % state_str(state)) - def _update_impl(self, state): - if (self.identity_map.contains_state(state) and - state not in self._deleted): - return - + obj = state.obj() + to_attach = self._before_attach(state, obj) + if state not in self._new: + self._new[state] = obj + state.insert_order = len(self._new) + if to_attach: + self._after_attach(state, obj) + + def _update_impl(self, state, revert_deletion=False): if state.key is None: raise sa_exc.InvalidRequestError( "Instance '%s' is not persisted" % - mapperutil.state_str(state)) + state_str(state)) + + if state._deleted: + if revert_deletion: + if not state._attached: + return + del state._deleted + else: + raise sa_exc.InvalidRequestError( + "Instance '%s' has been deleted. " + "Use the make_transient() " + "function to send this object back " + "to the transient state." 
%
+                    state_str(state)
+                )
+
+        obj = state.obj()
+
+        # check for late gc
+        if obj is None:
+            return
+
+        to_attach = self._before_attach(state, obj)
-        self._attach(state)
         self._deleted.pop(state, None)
-        self.identity_map.add(state)
+        if revert_deletion:
+            self.identity_map.replace(state)
+        else:
+            self.identity_map.add(state)
+
+        if to_attach:
+            self._after_attach(state, obj)
+        elif revert_deletion:
+            self.dispatch.deleted_to_persistent(self, obj)
 
     def _save_or_update_impl(self, state):
         if state.key is None:
@@ -1242,36 +1996,87 @@ class Session(object):
         else:
             self._update_impl(state)
 
-    def _delete_impl(self, state):
-        if state in self._deleted:
-            return
+    def enable_relationship_loading(self, obj):
+        """Associate an object with this :class:`.Session` for related
+        object loading.
 
-        if state.key is None:
-            return
-
-        self._attach(state)
-        self._deleted[state] = state.obj()
-        self.identity_map.add(state)
-
-    def _attach(self, state):
-        if state.key and \
-            state.key in self.identity_map and \
-            not self.identity_map.contains_state(state):
-            raise sa_exc.InvalidRequestError(
-                "Can't attach instance %s; another instance with key %s is already present in this session." %
-                (mapperutil.state_str(state), state.key)
-            )
-
-        if state.session_id and state.session_id is not self.hash_key:
+        .. warning::
+
+            :meth:`.enable_relationship_loading` exists to serve special
+            use cases and is not recommended for general use.
+
+        Accesses of attributes mapped with :func:`.relationship`
+        will attempt to load a value from the database using this
+        :class:`.Session` as the source of connectivity.  The values
+        will be loaded based on foreign key values present on this
+        object - it follows that this functionality
+        generally only works for many-to-one relationships.
+
+        The object will be attached to this session, but will
+        **not** participate in any persistence operations; its state
+        for almost all purposes will remain either "transient" or
+        "detached", except for the case of relationship loading.
+
+        Also note that backrefs will often not work as expected.
+        Altering a relationship-bound attribute on the target object
+        may not fire off a backref event, if the effective value
+        is what was already loaded from a foreign-key-holding value.
+
+        The :meth:`.Session.enable_relationship_loading` method is
+        similar to the ``load_on_pending`` flag on :func:`.relationship`.
+        Unlike that flag, :meth:`.Session.enable_relationship_loading` allows
+        an object to remain transient while still being able to load
+        related items.
+
+        To make a transient object associated with a :class:`.Session`
+        via :meth:`.Session.enable_relationship_loading` pending, add
+        it to the :class:`.Session` using :meth:`.Session.add` normally.
+
+        :meth:`.Session.enable_relationship_loading` does not improve
+        behavior when the ORM is used normally - object references should be
+        constructed at the object level, not at the foreign key level, so
+        that they are present in an ordinary way before flush()
+        proceeds.  This method is not intended for general use.
+
+        .. versionadded:: 0.8
+
+        .. seealso::
+
+            ``load_on_pending`` at :func:`.relationship` - this flag
+            allows per-relationship loading of many-to-ones on items that
+            are pending.
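+
+        As an illustrative sketch only (assuming a transient ``Address``
+        object whose ``user_id`` foreign-key attribute is set, and a
+        mapped many-to-one ``Address.user``)::
+
+            address = Address(user_id=5)
+            session.enable_relationship_loading(address)
+
+            # the many-to-one can now be loaded via this Session
+            print(address.user)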
+
+        """
+        state = attributes.instance_state(obj)
+        to_attach = self._before_attach(state, obj)
+        state._load_pending = True
+        if to_attach:
+            self._after_attach(state, obj)
+
+    def _before_attach(self, state, obj):
+        if state.session_id == self.hash_key:
+            return False
+
+        if state.session_id and state.session_id in _sessions:
             raise sa_exc.InvalidRequestError(
                 "Object '%s' is already attached to session '%s' "
-                "(this is '%s')" % (mapperutil.state_str(state),
+                "(this is '%s')" % (state_str(state),
                                     state.session_id, self.hash_key))
-
-        if state.session_id != self.hash_key:
-            state.session_id = self.hash_key
-            for ext in self.extensions:
-                ext.after_attach(self, state.obj())
+
+        self.dispatch.before_attach(self, obj)
+
+        return True
+
+    def _after_attach(self, state, obj):
+        state.session_id = self.hash_key
+        if state.modified and state._strong_obj is None:
+            state._strong_obj = obj
+        self.dispatch.after_attach(self, obj)
+
+        if state.key:
+            self.dispatch.detached_to_persistent(self, obj)
+        else:
+            self.dispatch.transient_to_pending(self, obj)
 
     def __contains__(self, instance):
         """Return True if the instance is associated with this session.
@@ -1287,9 +2092,12 @@ class Session(object):
         return self._contains_state(state)
 
     def __iter__(self):
-        """Iterate over all pending or persistent instances within this Session."""
+        """Iterate over all pending or persistent instances within this
+        Session.
 
-        return iter(list(self._new.values()) + self.identity_map.values())
+        """
+        return iter(
+            list(self._new.values()) + list(self.identity_map.values()))
 
     def _contains_state(self, state):
         return state in self._new or self.identity_map.contains_state(state)
@@ -1300,43 +2108,52 @@ class Session(object):
 
         Writes out all pending object creations, deletions and modifications
         to the database as INSERTs, DELETEs, UPDATEs, etc.  Operations are
         automatically ordered by the Session's unit of work dependency
-        solver..
+        solver.
 
         Database operations will be issued in the current transactional
-        context and do not affect the state of the transaction.  You may
-        flush() as often as you like within a transaction to move changes from
-        Python to the database's transaction buffer.
+        context and do not affect the state of the transaction, unless an
+        error occurs, in which case the entire transaction is rolled back.
+        You may flush() as often as you like within a transaction to move
+        changes from Python to the database's transaction buffer.
 
         For ``autocommit`` Sessions with no active manual transaction, flush()
         will create a transaction on the fly that surrounds the entire set of
-        operations int the flush.
+        operations in the flush.
 
-        objects
-          Optional; a list or tuple collection.  Restricts the flush operation
-          to only these objects, rather than all pending changes.
-          Deprecated - this flag prevents the session from properly maintaining
-          accounting among inter-object relations and can cause invalid results.
+        :param objects: Optional; restricts the flush operation to operate
+          only on elements that are in the given collection.
+
+          This feature is for an extremely narrow set of use cases where
+          particular objects may need to be operated upon before the
+          full flush() occurs.  It is not intended for general use.
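+
+        As an illustrative sketch only (assuming a mapped ``User`` class
+        and an open ``session``)::
+
+            session.add(User(name='ed'))
+            session.flush()   # INSERT emitted within the current transaction
+            session.commit()  # the transaction itself is committed separately
+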
""" - if objects: - util.warn_deprecated( - "The 'objects' argument to session.flush() is deprecated; " - "Please do not add objects to the session which should not yet be persisted.") - if self._flushing: raise sa_exc.InvalidRequestError("Session is already flushing") - + + if self._is_clean(): + return try: self._flushing = True self._flush(objects) finally: self._flushing = False - + + def _flush_warning(self, method): + util.warn( + "Usage of the '%s' operation is not currently supported " + "within the execution stage of the flush process. " + "Results may not be consistent. Consider using alternative " + "event listeners or connection-level operations instead." + % method) + + def _is_clean(self): + return not self.identity_map.check_modified() and \ + not self._deleted and \ + not self._new + def _flush(self, objects=None): - if (not self.identity_map.check_modified() and - not self._deleted and not self._new): - return dirty = self._dirty_states if not dirty and not self._deleted and not self._new: @@ -1345,11 +2162,12 @@ class Session(object): flush_context = UOWTransaction(self) - if self.extensions: - for ext in self.extensions: - ext.before_flush(self, flush_context, objects) + if self.dispatch.before_flush: + self.dispatch.before_flush(self, flush_context, objects) + # re-establish "dirty states" in case the listeners + # added dirty = self._dirty_states - + deleted = set(self._deleted) new = set(self._new) @@ -1377,20 +2195,12 @@ class Session(object): proc = new.union(dirty).intersection(objset).difference(deleted) else: proc = new.union(dirty).difference(deleted) - + for state in proc: - is_orphan = _state_mapper(state)._is_orphan(state) - if is_orphan and not _state_has_identity(state): - path = ", nor ".join( - ["any parent '%s' instance " - "via that classes' '%s' attribute" % - (cls.__name__, key) - for (key, cls) in chain(*(m.delete_orphans for m in _state_mapper(state).iterate_to_root()))]) - raise exc.FlushError( - "Instance %s is an unsaved, pending instance and is an " - "orphan (is not attached to %s)" % ( - mapperutil.state_str(state), path)) - flush_context.register_object(state, isdelete=is_orphan) + is_orphan = ( + _state_mapper(state)._is_orphan(state) and state.has_identity) + _reg = flush_context.register_object(state, isdelete=is_orphan) + assert _reg, "Failed to add object to the flush context!" processed.add(state) # put all remaining deletes into the flush context. @@ -1399,80 +2209,465 @@ class Session(object): else: proc = deleted.difference(processed) for state in proc: - flush_context.register_object(state, isdelete=True) + _reg = flush_context.register_object(state, isdelete=True) + assert _reg, "Failed to add object to the flush context!" - if len(flush_context.tasks) == 0: + if not flush_context.has_work: return flush_context.transaction = transaction = self.begin( subtransactions=True) try: - flush_context.execute() + self._warn_on_events = True + try: + flush_context.execute() + finally: + self._warn_on_events = False + + self.dispatch.after_flush(self, flush_context) + + flush_context.finalize_flush_changes() + + if not objects and self.identity_map._modified: + len_ = len(self.identity_map._modified) + + statelib.InstanceState._commit_all_states( + [(state, state.dict) for state in + self.identity_map._modified], + instance_dict=self.identity_map) + util.warn("Attribute history events accumulated on %d " + "previously clean instances " + "within inner-flush event handlers have been " + "reset, and will not result in database updates. 
" + "Consider using set_committed_value() within " + "inner-flush event handlers to avoid this warning." + % len_) + + # useful assertions: + # if not objects: + # assert not self.identity_map._modified + # else: + # assert self.identity_map._modified == \ + # self.identity_map._modified.difference(objects) + + self.dispatch.after_flush_postexec(self, flush_context) - for ext in self.extensions: - ext.after_flush(self, flush_context) transaction.commit() + except: - transaction.rollback() - raise - - flush_context.finalize_flush_changes() + with util.safe_reraise(): + transaction.rollback(_capture_exception=True) - # useful assertions: - #if not objects: - # assert not self.identity_map._modified - #else: - # assert self.identity_map._modified == self.identity_map._modified.difference(objects) - #self.identity_map._modified.clear() - - for ext in self.extensions: - ext.after_flush_postexec(self, flush_context) + def bulk_save_objects( + self, objects, return_defaults=False, update_changed_only=True): + """Perform a bulk save of the given list of objects. - def is_modified(self, instance, include_collections=True, passive=False): - """Return True if instance has modified attributes. + The bulk save feature allows mapped objects to be used as the + source of simple INSERT and UPDATE operations which can be more easily + grouped together into higher performing "executemany" + operations; the extraction of data from the objects is also performed + using a lower-latency process that ignores whether or not attributes + have actually been modified in the case of UPDATEs, and also ignores + SQL expressions. - This method retrieves a history instance for each instrumented - attribute on the instance and performs a comparison of the current - value to its previously committed value. Note that instances present - in the 'dirty' collection may result in a value of ``False`` when - tested with this method. + The objects as given are not added to the session and no additional + state is established on them, unless the ``return_defaults`` flag + is also set, in which case primary key attributes and server-side + default values will be populated. - `include_collections` indicates if multivalued collections should be - included in the operation. Setting this to False is a way to detect - only local-column based properties (i.e. scalar columns or many-to-one - foreign keys) that would result in an UPDATE for this instance upon - flush. + .. versionadded:: 1.0.0 - The `passive` flag indicates if unloaded attributes and collections - should not be loaded in the course of performing this test. + .. warning:: + + The bulk save feature allows for a lower-latency INSERT/UPDATE + of rows at the expense of most other unit-of-work features. + Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + INSERT/UPDATES of records. + + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** + + :param objects: a list of mapped object instances. The mapped + objects are persisted as is, and are **not** associated with the + :class:`.Session` afterwards. 
+ + For each object, whether the object is sent as an INSERT or an + UPDATE is dependent on the same rules used by the :class:`.Session` + in traditional operation; if the object has the + :attr:`.InstanceState.key` + attribute set, then the object is assumed to be "detached" and + will result in an UPDATE. Otherwise, an INSERT is used. + + In the case of an UPDATE, statements are grouped based on which + attributes have changed, and are thus to be the subject of each + SET clause. If ``update_changed_only`` is False, then all + attributes present within each object are applied to the UPDATE + statement, which may help in allowing the statements to be grouped + together into a larger executemany(), and will also reduce the + overhead of checking history on attributes. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary key values ahead of time; however, + :paramref:`.Session.bulk_save_objects.return_defaults` **greatly + reduces the performance gains** of the method overall. + + :param update_changed_only: when True, UPDATE statements are rendered + based on those attributes in each state that have logged changes. + When False, all attributes present are rendered into the SET clause + with the exception of primary key attributes. + + .. seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_update_mappings` """ + for (mapper, isupdate), states in itertools.groupby( + (attributes.instance_state(obj) for obj in objects), + lambda state: (state.mapper, state.key is not None) + ): + self._bulk_save_mappings( + mapper, states, isupdate, True, + return_defaults, update_changed_only, False) + + def bulk_insert_mappings( + self, mapper, mappings, return_defaults=False, render_nulls=False): + """Perform a bulk insert of the given list of mapping dictionaries. + + The bulk insert feature allows plain Python dictionaries to be used as + the source of simple INSERT operations which can be more easily + grouped together into higher performing "executemany" + operations. Using dictionaries, there is no "history" or session + state management features in use, reducing latency when inserting + large numbers of simple rows. + + The values within the dictionaries as given are typically passed + without modification into Core :meth:`.Insert` constructs, after + organizing the values within them across the tables to which + the given mapper is mapped. + + .. versionadded:: 1.0.0 + + .. warning:: + + The bulk insert feature allows for a lower-latency INSERT + of rows at the expense of most other unit-of-work features. + Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + INSERT of records. + + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** + + :param mapper: a mapped class, or the actual :class:`.Mapper` object, + representing the single kind of object represented within the mapping + list. 
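+
+      For illustration, a minimal sketch (assuming a mapped ``User``
+      class and an open ``session``)::
+
+          session.bulk_insert_mappings(
+              User,
+              [{"name": "u%d" % i} for i in range(10000)]
+          )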
+ + :param mappings: a list of dictionaries, each one containing the state + of the mapped row to be inserted, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary must contain + all keys to be populated into all tables. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary + key values ahead of time; however, + :paramref:`.Session.bulk_insert_mappings.return_defaults` + **greatly reduces the performance gains** of the method overall. + If the rows + to be inserted only refer to a single table, then there is no + reason this flag should be set as the returned default information + is not used. + + :param render_nulls: When True, a value of ``None`` will result + in a NULL value being included in the INSERT statement, rather + than the column being omitted from the INSERT. This allows all + the rows being INSERTed to have the identical set of columns which + allows the full set of rows to be batched to the DBAPI. Normally, + each column-set that contains a different combination of NULL values + than the previous row must omit a different series of columns from + the rendered INSERT statement, which means it must be emitted as a + separate statement. By passing this flag, the full set of rows + are guaranteed to be batchable into one batch; the cost however is + that server-side defaults which are invoked by an omitted column will + be skipped, so care must be taken to ensure that these are not + necessary. + + .. warning:: + + When this flag is set, **server side default SQL values will + not be invoked** for those columns that are inserted as NULL; + the NULL value will be sent explicitly. Care must be taken + to ensure that no server-side default functions need to be + invoked for the operation as a whole. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_save_objects` + + :meth:`.Session.bulk_update_mappings` + + """ + self._bulk_save_mappings( + mapper, mappings, False, False, + return_defaults, False, render_nulls) + + def bulk_update_mappings(self, mapper, mappings): + """Perform a bulk update of the given list of mapping dictionaries. + + The bulk update feature allows plain Python dictionaries to be used as + the source of simple UPDATE operations which can be more easily + grouped together into higher performing "executemany" + operations. Using dictionaries, there is no "history" or session + state management features in use, reducing latency when updating + large numbers of simple rows. + + .. versionadded:: 1.0.0 + + .. warning:: + + The bulk update feature allows for a lower-latency UPDATE + of rows at the expense of most other unit-of-work features. + Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + UPDATES of records. 
+ + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** + + :param mapper: a mapped class, or the actual :class:`.Mapper` object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a list of dictionaries, each one containing the state + of the mapped row to be updated, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary may contain + keys corresponding to all tables. All those keys which are present + and are not part of the primary key are applied to the SET clause + of the UPDATE statement; the primary key values, which are required, + are applied to the WHERE clause. + + + .. seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_save_objects` + + """ + self._bulk_save_mappings( + mapper, mappings, True, False, False, False, False) + + def _bulk_save_mappings( + self, mapper, mappings, isupdate, isstates, + return_defaults, update_changed_only, render_nulls): + mapper = _class_to_mapper(mapper) + self._flushing = True + + transaction = self.begin( + subtransactions=True) try: - state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + if isupdate: + persistence._bulk_update( + mapper, mappings, transaction, + isstates, update_changed_only) + else: + persistence._bulk_insert( + mapper, mappings, transaction, + isstates, return_defaults, render_nulls) + transaction.commit() + + except: + with util.safe_reraise(): + transaction.rollback(_capture_exception=True) + finally: + self._flushing = False + + def is_modified(self, instance, include_collections=True, + passive=True): + r"""Return ``True`` if the given instance has locally + modified attributes. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + .. versionchanged:: 0.8 + When using SQLAlchemy 0.7 and earlier, the ``passive`` + flag should **always** be explicitly set to ``True``, + else SQL loads/autoflushes may proceed which can affect + the modified state itself: + ``session.is_modified(someobject, passive=True)``\ . + In 0.8 and above, the behavior is corrected and + this flag is ignored. + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. + * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. 
SQLAlchemy in
+          most cases does not need the "old" value when a set event occurs, so
+          it skips the expense of a SQL call if the old value isn't present,
+          based on the assumption that an UPDATE of the scalar value is
+          usually needed, and in those few cases where it isn't, is less
+          expensive on average than issuing a defensive SELECT.
+
+          The "old" value is fetched unconditionally upon set only if the
+          attribute container has the ``active_history`` flag set to ``True``.
+          This flag is set typically for primary key attributes and scalar
+          object references that are not a simple many-to-one.  To set this
+          flag for any arbitrary mapped column, use the ``active_history``
+          argument with :func:`.column_property`.
+
+        :param instance: mapped instance to be tested for pending changes.
+        :param include_collections: Indicates if multivalued collections
+         should be included in the operation.  Setting this to ``False`` is a
+         way to detect only local-column based properties (i.e. scalar columns
+         or many-to-one foreign keys) that would result in an UPDATE for this
+         instance upon flush.
+        :param passive:
+
+         .. versionchanged:: 0.8
+             Ignored for backwards compatibility.
+             When using SQLAlchemy 0.7 and earlier, this flag should always
+             be set to ``True``.
+
+        """
+        state = object_state(instance)
+
+        if not state.modified:
+            return False
+
+        dict_ = state.dict
+
        for attr in state.manager.attributes:
            if \
-                (
-                    not include_collections and
-                    hasattr(attr.impl, 'get_collection')
-                ) or not hasattr(attr.impl, 'get_history'):
+                    (
+                        not include_collections and
+                        hasattr(attr.impl, 'get_collection')
+                    ) or not hasattr(attr.impl, 'get_history'):
                continue
-
+
            (added, unchanged, deleted) = \
-                attr.impl.get_history(state, dict_, passive=passive)
-
+                attr.impl.get_history(state, dict_,
+                                      passive=attributes.NO_CHANGE)
+
            if added or deleted:
                return True
-        return False
+        else:
+            return False
 
    @property
    def is_active(self):
-        """True if this Session has an active transaction."""
+        """True if this :class:`.Session` is in "transaction mode" and
+        is not in "partial rollback" state.
 
+        The :class:`.Session` in its default mode of ``autocommit=False``
+        is essentially always in "transaction mode", in that a
+        :class:`.SessionTransaction` is associated with it as soon as
+        it is instantiated.  This :class:`.SessionTransaction` is immediately
+        replaced with a new one as soon as it is ended, due to a rollback,
+        commit, or close operation.
+
+        "Transaction mode" does *not* indicate whether
+        or not actual database connection resources are in use; the
+        :class:`.SessionTransaction` object coordinates among zero or more
+        actual database transactions, and starts out with none, accumulating
+        individual DBAPI connections as different data sources are used
+        within its scope.  The best way to track when a particular
+        :class:`.Session` has actually begun to use DBAPI resources is to
+        implement a listener using the :meth:`.SessionEvents.after_begin`
+        method, which will deliver both the :class:`.Session` as well as the
+        target :class:`.Connection` to a user-defined event listener.
+
+        The "partial rollback" state refers to when an "inner" transaction,
+        typically used during a flush, encounters an error and emits a
+        rollback of the DBAPI connection.  At this point, the
+        :class:`.Session` is in "partial rollback" and waits for the user to
+        call :meth:`.Session.rollback`, in order to close out the
+        transaction stack.  It is in this "partial rollback" period that the
+        :attr:`.is_active` flag returns False. 
After the call to
+        :meth:`.Session.rollback`, the :class:`.SessionTransaction` is
+        replaced with a new one and :attr:`.is_active` returns ``True`` again.
+
+        When a :class:`.Session` is used in ``autocommit=True`` mode, the
+        :class:`.SessionTransaction` is only instantiated within the scope
+        of a flush call, or when :meth:`.Session.begin` is called.  So
+        :attr:`.is_active` will always be ``False`` outside of a flush or
+        :meth:`.Session.begin` block in this mode, and will be ``True``
+        within the :meth:`.Session.begin` block as long as it doesn't enter
+        "partial rollback" state.
+
+        From all the above, it follows that the only purpose of this flag is
+        for application frameworks that wish to detect if a "rollback" is
+        necessary within a generic error handling routine, for
+        :class:`.Session` objects that would otherwise be in
+        "partial rollback" mode.  In a typical integration case, this is also
+        not necessary as it is standard practice to emit
+        :meth:`.Session.rollback` unconditionally within the outermost
+        exception catch.
+
+        To track the transactional state of a :class:`.Session` fully,
+        use event listeners, primarily the :meth:`.SessionEvents.after_begin`,
+        :meth:`.SessionEvents.after_commit`,
+        :meth:`.SessionEvents.after_rollback` and related events.
+
+        """
        return self.transaction and self.transaction.is_active
 
+    identity_map = None
+    """A mapping of object identities to objects themselves.
+
+    Iterating through ``Session.identity_map.values()`` provides
+    access to the full set of persistent objects (i.e., those
+    that have row identity) currently in the session.
+
+    .. seealso::
+
+        :func:`.identity_key` - helper function to produce the keys used
+        in this dictionary.
+
+    """
+
    @property
    def _dirty_states(self):
        """The set of all persistent states considered dirty.
@@ -1487,6 +2682,10 @@
    def dirty(self):
        """The set of all persistent instances considered dirty.
 
+        E.g.::
+
+            some_mapped_object in session.dirty
+
        Instances are considered dirty when they were modified but not
        deleted.
 
@@ -1500,7 +2699,7 @@
        it's only done at flush time).
 
        To check if an instance has actionable net changes to its
-        attributes, use the is_modified() method.
+        attributes, use the :meth:`.Session.is_modified` method.
 
        """
        return util.IdentitySet(
@@ -1512,93 +2711,260 @@
    def deleted(self):
        "The set of all instances marked as 'deleted' within this ``Session``"
 
-        return util.IdentitySet(self._deleted.values())
+        return util.IdentitySet(list(self._deleted.values()))
 
    @property
    def new(self):
        "The set of all instances marked as 'new' within this ``Session``."
 
-        return util.IdentitySet(self._new.values())
+        return util.IdentitySet(list(self._new.values()))
 
-_expire_state = state.InstanceState.expire_attributes
-
-UOWEventHandler = unitofwork.UOWEventHandler
-_sessions = weakref.WeakValueDictionary()
+class sessionmaker(_SessionClassMethods):
+    """A configurable :class:`.Session` factory.
 
-def _cascade_state_iterator(cascade, state, **kwargs):
-    mapper = _state_mapper(state)
-    # yield the state, object, mapper.  yielding the object
-    # allows the iterator's results to be held in a list without
-    # states being garbage collected
-    for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs):
-        yield attributes.instance_state(o), o, m
+    The :class:`.sessionmaker` factory generates new
+    :class:`.Session` objects when called, creating them given
+    the configurational arguments established here. 
-def _cascade_unknown_state_iterator(cascade, state, **kwargs):
-    mapper = _state_mapper(state)
-    for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs):
-        yield _state_for_unknown_persistence_instance(o), m
+    e.g.::
 
-def _state_for_unsaved_instance(instance, create=False):
-    try:
-        state = attributes.instance_state(instance)
-    except AttributeError:
-        raise exc.UnmappedInstanceError(instance)
-    if state:
-        if state.key is not None:
-            raise sa_exc.InvalidRequestError(
-                "Instance '%s' is already persistent" %
-                mapperutil.state_str(state))
-    elif create:
-        manager = attributes.manager_of_class(instance.__class__)
-        if manager is None:
-            raise exc.UnmappedInstanceError(instance)
-        state = manager.setup_instance(instance)
-    else:
-        raise exc.UnmappedInstanceError(instance)
+        # global scope
+        Session = sessionmaker(autoflush=False)
 
-    return state
+        # later, in a local scope, create and use a session:
+        sess = Session()
 
-def _state_for_unknown_persistence_instance(instance):
-    try:
-        state = attributes.instance_state(instance)
-    except exc.NO_STATE:
-        raise exc.UnmappedInstanceError(instance)
+    Any keyword arguments sent to the constructor itself will override the
+    "configured" keywords::
+
+        Session = sessionmaker()
+
+        # bind an individual session to a connection
+        sess = Session(bind=connection)
+
+    The class also includes a method :meth:`.configure`, which can
+    be used to specify additional keyword arguments to the factory, which
+    will take effect for subsequent :class:`.Session` objects generated.
+    This is usually used to associate one or more :class:`.Engine` objects
+    with an existing :class:`.sessionmaker` factory before it is first
+    used::
+
+        # application starts
+        Session = sessionmaker()
+
+        # ... later
+        engine = create_engine('sqlite:///foo.db')
+        Session.configure(bind=engine)
+
+        sess = Session()
+
+    .. seealso::
+
+        :ref:`session_getting` - introductory text on creating
+        sessions using :class:`.sessionmaker`.
+
+    """
+
+    def __init__(self, bind=None, class_=Session, autoflush=True,
+                 autocommit=False,
+                 expire_on_commit=True,
+                 info=None, **kw):
+        r"""Construct a new :class:`.sessionmaker`.
+
+        All arguments here except for ``class_`` correspond to arguments
+        accepted by :class:`.Session` directly.  See the
+        :meth:`.Session.__init__` docstring for more details on parameters.
+
+        :param bind: a :class:`.Engine` or other :class:`.Connectable` with
+         which newly created :class:`.Session` objects will be associated.
+        :param class_: class to use in order to create new :class:`.Session`
+         objects.  Defaults to :class:`.Session`.
+        :param autoflush: The autoflush setting to use with newly created
+         :class:`.Session` objects.
+        :param autocommit: The autocommit setting to use with newly created
+         :class:`.Session` objects.
+        :param expire_on_commit=True: the expire_on_commit setting to use
+         with newly created :class:`.Session` objects.
+        :param info: optional dictionary of information that will be available
+         via :attr:`.Session.info`.  Note this dictionary is *updated*, not
+         replaced, when the ``info`` parameter is specified to the specific
+         :class:`.Session` construction operation.
+
+         .. versionadded:: 0.9.0
+
+        :param \**kw: all other keyword arguments are passed to the
+         constructor of newly created :class:`.Session` objects. 
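+
+        As an illustrative sketch only (assuming an ``engine`` created
+        elsewhere with :func:`.create_engine`)::
+
+            Session = sessionmaker(
+                bind=engine, expire_on_commit=False,
+                info={'request_id': None})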
+ + """ + kw['bind'] = bind + kw['autoflush'] = autoflush + kw['autocommit'] = autocommit + kw['expire_on_commit'] = expire_on_commit + if info is not None: + kw['info'] = info + self.kw = kw + # make our own subclass of the given class, so that + # events can be associated with it specifically. + self.class_ = type(class_.__name__, (class_,), {}) + + def __call__(self, **local_kw): + """Produce a new :class:`.Session` object using the configuration + established in this :class:`.sessionmaker`. + + In Python, the ``__call__`` method is invoked on an object when + it is "called" in the same way as a function:: + + Session = sessionmaker() + session = Session() # invokes sessionmaker.__call__() + + """ + for k, v in self.kw.items(): + if k == 'info' and 'info' in local_kw: + d = v.copy() + d.update(local_kw['info']) + local_kw['info'] = d + else: + local_kw.setdefault(k, v) + return self.class_(**local_kw) + + def configure(self, **new_kw): + """(Re)configure the arguments for this sessionmaker. + + e.g.:: + + Session = sessionmaker() + + Session.configure(bind=create_engine('sqlite://')) + """ + self.kw.update(new_kw) + + def __repr__(self): + return "%s(class_=%r,%s)" % ( + self.__class__.__name__, + self.class_.__name__, + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()) + ) - return state def make_transient(instance): - """Make the given instance 'transient'. - - This will remove its association with any - session and additionally will remove its "identity key", - such that it's as though the object were newly constructed, - except retaining its values. - + """Alter the state of the given instance so that it is :term:`transient`. + + .. note:: + + :func:`.make_transient` is a special-case function for + advanced use cases only. + + The given mapped instance is assumed to be in the :term:`persistent` or + :term:`detached` state. The function will remove its association with any + :class:`.Session` as well as its :attr:`.InstanceState.identity`. The + effect is that the object will behave as though it were newly constructed, + except retaining any attribute / collection values that were loaded at the + time of the call. The :attr:`.InstanceState.deleted` flag is also reset + if this object had been deleted as a result of using + :meth:`.Session.delete`. + + .. warning:: + + :func:`.make_transient` does **not** "unexpire" or otherwise eagerly + load ORM-mapped attributes that are not currently loaded at the time + the function is called. This includes attributes which: + + * were expired via :meth:`.Session.expire` + + * were expired as the natural effect of committing a session + transaction, e.g. :meth:`.Session.commit` + + * are normally :term:`lazy loaded` but are not currently loaded + + * are "deferred" via :ref:`deferred` and are not yet loaded + + * were not present in the query which loaded this object, such as that + which is common in joined table inheritance and other scenarios. + + After :func:`.make_transient` is called, unloaded attributes such + as those above will normally resolve to the value ``None`` when + accessed, or an empty collection for a collection-oriented attribute. + As the object is transient and un-associated with any database + identity, it will no longer retrieve these values. + + .. 
seealso:: + + :func:`.make_transient_to_detached` + """ state = attributes.instance_state(instance) s = _state_session(state) if s: - s._expunge_state(state) - del state.key - - + s._expunge_states([state]) + + # remove expired state + state.expired_attributes.clear() + + # remove deferred callables + if state.callables: + del state.callables + + if state.key: + del state.key + if state._deleted: + del state._deleted + + +def make_transient_to_detached(instance): + """Make the given transient instance :term:`detached`. + + .. note:: + + :func:`.make_transient_to_detached` is a special-case function for + advanced use cases only. + + All attribute history on the given instance + will be reset as though the instance were freshly loaded + from a query. Missing attributes will be marked as expired. + The primary key attributes of the object, which are required, will be made + into the "key" of the instance. + + The object can then be added to a session, or merged + possibly with the load=False flag, at which point it will look + as if it were loaded that way, without emitting SQL. + + This is a special use case function that differs from a normal + call to :meth:`.Session.merge` in that a given persistent state + can be manufactured without any SQL calls. + + .. versionadded:: 0.9.5 + + .. seealso:: + + :func:`.make_transient` + + """ + state = attributes.instance_state(instance) + if state.session_id or state.key: + raise sa_exc.InvalidRequestError( + "Given object must be transient") + state.key = state.mapper._identity_key_from_state(state) + if state._deleted: + del state._deleted + state._commit_all(state.dict) + state._expire_attributes(state.dict, state.unloaded) + + def object_session(instance): - """Return the ``Session`` to which instance belongs, or None.""" + """Return the :class:`.Session` to which the given instance belongs. - return _state_session(attributes.instance_state(instance)) + This is essentially the same as the :attr:`.InstanceState.session` + accessor. See that attribute for details. -def _state_session(state): - if state.session_id: - try: - return _sessions[state.session_id] - except KeyError: - pass - return None + """ -# Lazy initialization to avoid circular imports -unitofwork.object_session = object_session -unitofwork._state_session = _state_session -from sqlalchemy.orm import mapper -mapper._expire_state = _expire_state -mapper._state_session = _state_session + try: + state = attributes.instance_state(instance) + except exc.NO_STATE: + raise exc.UnmappedInstanceError(instance) + else: + return _state_session(state) + + +_new_sessionid = util.counter() diff --git a/sqlalchemy/orm/state.py b/sqlalchemy/orm/state.py index 25466b3..0fba240 100644 --- a/sqlalchemy/orm/state.py +++ b/sqlalchemy/orm/state.py @@ -1,287 +1,616 @@ -from sqlalchemy.util import EMPTY_SET +# orm/state.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Defines instrumentation of instances. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + +""" + import weakref -from sqlalchemy import util -from sqlalchemy.orm.attributes import PASSIVE_NO_RESULT, PASSIVE_OFF, \ - NEVER_SET, NO_VALUE, manager_of_class, \ - ATTR_WAS_SET -from sqlalchemy.orm import attributes, exc as orm_exc, interfaces +from .. import util +from .. import inspection +from . 
import exc as orm_exc, interfaces +from .path_registry import PathRegistry +from .base import PASSIVE_NO_RESULT, SQL_OK, NEVER_SET, ATTR_WAS_SET, \ + NO_VALUE, PASSIVE_NO_INITIALIZE, INIT_OK, PASSIVE_OFF +from . import base -import sys -attributes.state = sys.modules['sqlalchemy.orm.state'] -class InstanceState(object): - """tracks state information at the instance level.""" +@inspection._self_inspects +class InstanceState(interfaces.InspectionAttr): + """tracks state information at the instance level. + + The :class:`.InstanceState` is a key object used by the + SQLAlchemy ORM in order to track the state of an object; + it is created the moment an object is instantiated, typically + as a result of :term:`instrumentation` which SQLAlchemy applies + to the ``__init__()`` method of the class. + + :class:`.InstanceState` is also a semi-public object, + available for runtime inspection as to the state of a + mapped instance, including information such as its current + status within a particular :class:`.Session` and details + about data on individual attributes. The public API + in order to acquire a :class:`.InstanceState` object + is to use the :func:`.inspect` system:: + + >>> from sqlalchemy import inspect + >>> insp = inspect(some_mapped_object) + + .. seealso:: + + :ref:`core_inspection_toplevel` + + """ session_id = None key = None runid = None - load_options = EMPTY_SET + load_options = util.EMPTY_SET load_path = () insert_order = None - mutable_dict = None _strong_obj = None modified = False expired = False - + _deleted = False + _load_pending = False + is_instance = True + + callables = () + """A namespace where a per-state loader callable can be associated. + + In SQLAlchemy 1.0, this is only used for lazy loaders / deferred + loaders that were set up via query option. + + Previously, callables was used also to indicate expired attributes + by storing a link to the InstanceState itself in this dictionary. + This role is now handled by the expired_attributes set. + + """ + def __init__(self, obj, manager): self.class_ = obj.__class__ self.manager = manager self.obj = weakref.ref(obj, self._cleanup) + self.committed_state = {} + self.expired_attributes = set() + + expired_attributes = None + """The set of keys which are 'expired' to be loaded by + the manager's deferred scalar loader, assuming no pending + changes. + + see also the ``unmodified`` collection which is intersected + against this set when a refresh operation occurs.""" @util.memoized_property - def committed_state(self): - return {} - + def attrs(self): + """Return a namespace representing each attribute on + the mapped object, including its current value + and history. + + The returned object is an instance of :class:`.AttributeState`. + This object allows inspection of the current data + within an attribute as well as attribute history + since the last flush. + + """ + return util.ImmutableProperties( + dict( + (key, AttributeState(self, key)) + for key in self.manager + ) + ) + + @property + def transient(self): + """Return true if the object is :term:`transient`. + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is None and \ + not self._attached + + @property + def pending(self): + """Return true if the object is :term:`pending`. + + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is None and \ + self._attached + + @property + def deleted(self): + """Return true if the object is :term:`deleted`. 
+
+        An object that is in the deleted state is guaranteed to
+        not be within the :attr:`.Session.identity_map` of its parent
+        :class:`.Session`; however if the session's transaction is rolled
+        back, the object will be restored to the persistent state and
+        the identity map.
+
+        .. note::
+
+            The :attr:`.InstanceState.deleted` attribute refers to a specific
+            state of the object that occurs between the "persistent" and
+            "detached" states; once the object is :term:`detached`, the
+            :attr:`.InstanceState.deleted` attribute **no longer returns
+            True**; in order to detect that a state was deleted, regardless
+            of whether or not the object is associated with a :class:`.Session`,
+            use the :attr:`.InstanceState.was_deleted` accessor.
+
+        .. versionadded:: 1.1
+
+        .. seealso::
+
+            :ref:`session_object_states`
+
+        """
+        return self.key is not None and \
+            self._attached and self._deleted
+
+    @property
+    def was_deleted(self):
+        """Return True if this object is or was previously in the
+        "deleted" state and has not been reverted to persistent.
+
+        This flag returns True once the object was deleted in flush.
+        When the object is expunged from the session either explicitly
+        or via transaction commit and enters the "detached" state,
+        this flag will continue to report True.
+
+        .. versionadded:: 1.1 - added a local method form of
+           :func:`.orm.util.was_deleted`.
+
+        .. seealso::
+
+            :attr:`.InstanceState.deleted` - refers to the "deleted" state
+
+            :func:`.orm.util.was_deleted` - standalone function
+
+            :ref:`session_object_states`
+
+        """
+        return self._deleted
+
+    @property
+    def persistent(self):
+        """Return true if the object is :term:`persistent`.
+
+        An object that is in the persistent state is guaranteed to
+        be within the :attr:`.Session.identity_map` of its parent
+        :class:`.Session`.
+
+        .. versionchanged:: 1.1 The :attr:`.InstanceState.persistent`
+           accessor no longer returns True for an object that was
+           "deleted" within a flush; use the :attr:`.InstanceState.deleted`
+           accessor to detect this state.   This allows the "persistent"
+           state to guarantee membership in the identity map.
+
+        .. seealso::
+
+            :ref:`session_object_states`
+
+        """
+        return self.key is not None and \
+            self._attached and not self._deleted
+
+    @property
+    def detached(self):
+        """Return true if the object is :term:`detached`.
+
+        .. seealso::
+
+            :ref:`session_object_states`
+
+        """
+        return self.key is not None and not self._attached
+
+    @property
+    @util.dependencies("sqlalchemy.orm.session")
+    def _attached(self, sessionlib):
+        return self.session_id is not None and \
+            self.session_id in sessionlib._sessions
+
+    @property
+    @util.dependencies("sqlalchemy.orm.session")
+    def session(self, sessionlib):
+        """Return the owning :class:`.Session` for this instance,
+        or ``None`` if none available.
+
+        Note that the result here can in some cases be *different*
+        from that of ``obj in session``; an object that's been deleted
+        will report as not ``in session``, however if the transaction is
+        still in progress, this attribute will still refer to that session.
+        Only when the transaction is completed does the object become
+        fully detached under normal circumstances.
+
+        """
+        return sessionlib._state_session(self)
+
+    @property
+    def object(self):
+        """Return the mapped object represented by this
+        :class:`.InstanceState`."""
+        return self.obj()
+
+    @property
+    def identity(self):
+        """Return the mapped identity of the mapped object. 
+        This is the primary key identity as persisted by the ORM
+        which can always be passed directly to
+        :meth:`.Query.get`.
+
+        Returns ``None`` if the object has no primary key identity.
+
+        .. note::
+            An object which is :term:`transient` or :term:`pending`
+            does **not** have a mapped identity until it is flushed,
+            even if its attributes include primary key values.
+
+        """
+        if self.key is None:
+            return None
+        else:
+            return self.key[1]
+
+    @property
+    def identity_key(self):
+        """Return the identity key for the mapped object.
+
+        This is the key used to locate the object within
+        the :attr:`.Session.identity_map` mapping. It contains
+        the identity as returned by :attr:`.identity` within it.
+
+
+        """
+        # TODO: just change .key to .identity_key across
+        # the board ? probably
+        return self.key
+
    @util.memoized_property
    def parents(self):
        return {}

    @util.memoized_property
-    def pending(self):
+    def _pending_mutations(self):
        return {}

    @util.memoized_property
-    def callables(self):
-        return {}
-
-    def detach(self):
-        if self.session_id:
-            try:
-                del self.session_id
-            except AttributeError:
-                pass
+    def mapper(self):
+        """Return the :class:`.Mapper` used for this mapped object."""
+        return self.manager.mapper

-    def dispose(self):
-        self.detach()
+    @property
+    def has_identity(self):
+        """Return ``True`` if this object has an identity key.
+
+        This should always have the same value as the
+        expression ``state.persistent or state.detached``.
+
+        """
+        return bool(self.key)
+
+    @classmethod
+    def _detach_states(self, states, session, to_transient=False):
+        persistent_to_detached = \
+            session.dispatch.persistent_to_detached or None
+        deleted_to_detached = \
+            session.dispatch.deleted_to_detached or None
+        pending_to_transient = \
+            session.dispatch.pending_to_transient or None
+        persistent_to_transient = \
+            session.dispatch.persistent_to_transient or None
+
+        for state in states:
+            deleted = state._deleted
+            pending = state.key is None
+            persistent = not pending and not deleted
+
+            state.session_id = None
+
+            if to_transient and state.key:
+                del state.key
+            if persistent:
+                if to_transient:
+                    if persistent_to_transient is not None:
+                        obj = state.obj()
+                        if obj is not None:
+                            persistent_to_transient(session, obj)
+                elif persistent_to_detached is not None:
+                    obj = state.obj()
+                    if obj is not None:
+                        persistent_to_detached(session, obj)
+            elif deleted and deleted_to_detached is not None:
+                obj = state.obj()
+                if obj is not None:
+                    deleted_to_detached(session, obj)
+            elif pending and pending_to_transient is not None:
+                obj = state.obj()
+                if obj is not None:
+                    pending_to_transient(session, obj)
+
+            state._strong_obj = None
+
+    def _detach(self, session=None):
+        if session:
+            InstanceState._detach_states([self], session)
+        else:
+            self.session_id = self._strong_obj = None
+
+    def _dispose(self):
+        self._detach()
        del self.obj
-
+    def _cleanup(self, ref):
+        """Weakref callback cleanup.
+
+        This callable cleans out the state when it is being garbage
+        collected.
+
+        this _cleanup **assumes** that there are no strong refs to us!
+        Will not work otherwise!
+
+        """
        instance_dict = self._instance_dict()
-        if instance_dict:
-            try:
-                instance_dict.remove(self)
-            except AssertionError:
-                pass
-        # remove possible cycles
-        self.__dict__.pop('callables', None)
-        self.dispose()
-
+        if instance_dict is not None:
+            instance_dict._fast_discard(self)
+            del self._instance_dict
+
+        # we can't possibly be in instance_dict._modified
+        # b.c. this is weakref cleanup only, that set
+        # is strong referencing!
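# ------------------------------------------------------------------
# [editor's note] Illustrative sketch, not part of the patch: how the
# InstanceState accessors added above are reached through the public
# inspect() API.  The "User" model and the in-memory SQLite engine are
# hypothetical and exist only for this example.
from sqlalchemy import Column, Integer, String, create_engine, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

Base = declarative_base()


class User(Base):
    __tablename__ = 'user'
    id = Column(Integer, primary_key=True)
    name = Column(String(50))


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

user = User(name='ed')
insp = inspect(user)           # the InstanceState for this object
assert insp.transient          # no identity key, no owning Session

session = Session(engine)
session.add(user)
assert insp.pending            # attached, but not yet flushed

session.flush()
assert insp.persistent         # identity key plus Session membership
print(insp.identity)           # primary-key identity tuple, e.g. (1,)
# ------------------------------------------------------------------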
+ # assert self not in instance_dict._modified + + self.session_id = self._strong_obj = None + del self.obj + def obj(self): return None - + @property def dict(self): + """Return the instance dict used by the object. + + Under normal circumstances, this is always synonymous + with the ``__dict__`` attribute of the mapped object, + unless an alternative instrumentation system has been + configured. + + In the case that the actual object has been garbage + collected, this accessor returns a blank dictionary. + + """ o = self.obj() if o is not None: - return attributes.instance_dict(o) + return base.instance_dict(o) else: return {} - - @property - def sort_key(self): - return self.key and self.key[1] or (self.insert_order, ) - def initialize_instance(*mixed, **kwargs): - self, instance, args = mixed[0], mixed[1], mixed[2:] + def _initialize_instance(*mixed, **kwargs): + self, instance, args = mixed[0], mixed[1], mixed[2:] # noqa manager = self.manager - for fn in manager.events.on_init: - fn(self, instance, args, kwargs) - - # LESSTHANIDEAL: - # adjust for the case where the InstanceState was created before - # mapper compilation, and this actually needs to be a MutableAttrInstanceState - if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState: - self.__class__ = MutableAttrInstanceState - self.obj = weakref.ref(self.obj(), self._cleanup) - self.mutable_dict = {} - - try: - return manager.events.original_init(*mixed[1:], **kwargs) - except: - for fn in manager.events.on_init_failure: - fn(self, instance, args, kwargs) - raise + manager.dispatch.init(self, args, kwargs) - def get_history(self, key, **kwargs): - return self.manager.get_impl(key).get_history(self, self.dict, **kwargs) + try: + return manager.original_init(*mixed[1:], **kwargs) + except: + with util.safe_reraise(): + manager.dispatch.init_failure(self, args, kwargs) + + def get_history(self, key, passive): + return self.manager[key].impl.get_history(self, self.dict, passive) def get_impl(self, key): - return self.manager.get_impl(key) + return self.manager[key].impl - def get_pending(self, key): - if key not in self.pending: - self.pending[key] = PendingCollection() - return self.pending[key] + def _get_pending_mutation(self, key): + if key not in self._pending_mutations: + self._pending_mutations[key] = PendingCollection() + return self._pending_mutations[key] - def value_as_iterable(self, key, passive=PASSIVE_OFF): - """return an InstanceState attribute as a list, - regardless of it being a scalar or collection-based - attribute. - - returns None if passive is not PASSIVE_OFF and the getter returns - PASSIVE_NO_RESULT. 
- """ - - impl = self.get_impl(key) - dict_ = self.dict - x = impl.get(self, dict_, passive=passive) - if x is PASSIVE_NO_RESULT: - return None - elif hasattr(impl, 'get_collection'): - return impl.get_collection(self, dict_, x, passive=passive) - else: - return [x] - - def _run_on_load(self, instance): - self.manager.events.run('on_load', instance) - def __getstate__(self): - d = {'instance':self.obj()} - - d.update( + state_dict = {'instance': self.obj()} + state_dict.update( (k, self.__dict__[k]) for k in ( - 'committed_state', 'pending', 'parents', 'modified', 'expired', - 'callables', 'key', 'load_options', 'mutable_dict' - ) if k in self.__dict__ + 'committed_state', '_pending_mutations', 'modified', + 'expired', 'callables', 'key', 'parents', 'load_options', + 'class_', 'expired_attributes' + ) if k in self.__dict__ ) if self.load_path: - d['load_path'] = interfaces.serialize_path(self.load_path) - return d - - def __setstate__(self, state): - self.obj = weakref.ref(state['instance'], self._cleanup) - self.class_ = state['instance'].__class__ - self.manager = manager = manager_of_class(self.class_) - if manager is None: - raise orm_exc.UnmappedInstanceError( - state['instance'], - "Cannot deserialize object of type %r - no mapper() has" - " been configured for this class within the current Python process!" % - self.class_) - elif manager.mapper and not manager.mapper.compiled: - manager.mapper.compile() - - self.committed_state = state.get('committed_state', {}) - self.pending = state.get('pending', {}) - self.parents = state.get('parents', {}) - self.modified = state.get('modified', False) - self.expired = state.get('expired', False) - self.callables = state.get('callables', {}) - - if self.modified: - self._strong_obj = state['instance'] - + state_dict['load_path'] = self.load_path.serialize() + + state_dict['manager'] = self.manager._serialize(self, state_dict) + + return state_dict + + def __setstate__(self, state_dict): + inst = state_dict['instance'] + if inst is not None: + self.obj = weakref.ref(inst, self._cleanup) + self.class_ = inst.__class__ + else: + # None being possible here generally new as of 0.7.4 + # due to storage of state in "parents". "class_" + # also new. 
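# ------------------------------------------------------------------
# [editor's note] Sketch, not part of the patch: the __getstate__ /
# __setstate__ pair above is what lets a mapped instance survive
# pickling, with its InstanceState rebuilt on deserialization.  Reuses
# the hypothetical "User" model from the earlier sketch.
import pickle

from sqlalchemy import inspect

original = User(name='wendy')
copy = pickle.loads(pickle.dumps(original))

assert copy.name == 'wendy'
assert inspect(copy).transient   # rebuilt state is detached from any Session
# ------------------------------------------------------------------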
+ self.obj = None + self.class_ = state_dict['class_'] + + self.committed_state = state_dict.get('committed_state', {}) + self._pending_mutations = state_dict.get('_pending_mutations', {}) + self.parents = state_dict.get('parents', {}) + self.modified = state_dict.get('modified', False) + self.expired = state_dict.get('expired', False) + if 'callables' in state_dict: + self.callables = state_dict['callables'] + + try: + self.expired_attributes = state_dict['expired_attributes'] + except KeyError: + self.expired_attributes = set() + # 0.9 and earlier compat + for k in list(self.callables): + if self.callables[k] is self: + self.expired_attributes.add(k) + del self.callables[k] + self.__dict__.update([ - (k, state[k]) for k in ( - 'key', 'load_options', 'mutable_dict' - ) if k in state + (k, state_dict[k]) for k in ( + 'key', 'load_options', + ) if k in state_dict ]) - if 'load_path' in state: - self.load_path = interfaces.deserialize_path(state['load_path']) + if 'load_path' in state_dict: + self.load_path = PathRegistry.\ + deserialize(state_dict['load_path']) - def initialize(self, key): - """Set this attribute to an empty value or collection, - based on the AttributeImpl in use.""" - - self.manager.get_impl(key).initialize(self, self.dict) + state_dict['manager'](self, inst, state_dict) - def reset(self, dict_, key): - """Remove the given attribute and any + def _reset(self, dict_, key): + """Remove the given attribute and any callables associated with it.""" - dict_.pop(key, None) - self.callables.pop(key, None) + old = dict_.pop(key, None) + if old is not None and self.manager[key].impl.collection: + self.manager[key].impl._invalidate_collection(old) + self.expired_attributes.discard(key) + if self.callables: + self.callables.pop(key, None) - def expire_attribute_pre_commit(self, dict_, key): - """a fast expire that can be called by column loaders during a load. + def _copy_callables(self, from_): + if 'callables' in from_.__dict__: + self.callables = dict(from_.callables) - The additional bookkeeping is finished up in commit_all(). - - This method is actually called a lot with joined-table - loading, when the second table isn't present in the result. - - """ - dict_.pop(key, None) - self.callables[key] = self - - def set_callable(self, dict_, key, callable_): - """Remove the given attribute and set the given callable - as a loader.""" - - dict_.pop(key, None) - self.callables[key] = callable_ - - def expire_attributes(self, dict_, attribute_names, instance_dict=None): - """Expire all or a group of attributes. - - If all attributes are expired, the "expired" flag is set to True. 
- - """ - if attribute_names is None: - attribute_names = self.manager.keys() - self.expired = True - if self.modified: - if not instance_dict: - instance_dict = self._instance_dict() - if instance_dict: - instance_dict._modified.discard(self) - else: - instance_dict._modified.discard(self) - - self.modified = False - filter_deferred = True + @classmethod + def _instance_level_callable_processor(cls, manager, fn, key): + impl = manager[key].impl + if impl.collection: + def _set_callable(state, dict_, row): + if 'callables' not in state.__dict__: + state.callables = {} + old = dict_.pop(key, None) + if old is not None: + impl._invalidate_collection(old) + state.callables[key] = fn else: - filter_deferred = False + def _set_callable(state, dict_, row): + if 'callables' not in state.__dict__: + state.callables = {} + state.callables[key] = fn + return _set_callable - to_clear = ( - self.__dict__.get('pending', None), - self.__dict__.get('committed_state', None), - self.mutable_dict + def _expire(self, dict_, modified_set): + self.expired = True + + if self.modified: + modified_set.discard(self) + self.committed_state.clear() + self.modified = False + + self._strong_obj = None + + if '_pending_mutations' in self.__dict__: + del self.__dict__['_pending_mutations'] + + if 'parents' in self.__dict__: + del self.__dict__['parents'] + + self.expired_attributes.update( + [impl.key for impl in self.manager._scalar_loader_impls + if impl.expire_missing or impl.key in dict_] ) - + + if self.callables: + for k in self.expired_attributes.intersection(self.callables): + del self.callables[k] + + for k in self.manager._collection_impl_keys.intersection(dict_): + collection = dict_.pop(k) + collection._sa_adapter.invalidated = True + + for key in self.manager._all_key_set.intersection(dict_): + del dict_[key] + + self.manager.dispatch.expire(self, None) + + def _expire_attributes(self, dict_, attribute_names, no_loader=False): + pending = self.__dict__.get('_pending_mutations', None) + + callables = self.callables + for key in attribute_names: impl = self.manager[key].impl - if impl.accepts_scalar_loader and \ - (not filter_deferred or impl.expire_missing or key in dict_): - self.callables[key] = self - dict_.pop(key, None) - - for d in to_clear: - if d is not None: - d.pop(key, None) + if impl.accepts_scalar_loader: + if no_loader and ( + impl.callable_ or + key in callables + ): + continue - def __call__(self, **kw): + self.expired_attributes.add(key) + if callables and key in callables: + del callables[key] + old = dict_.pop(key, None) + if impl.collection and old is not None: + impl._invalidate_collection(old) + + self.committed_state.pop(key, None) + if pending: + pending.pop(key, None) + + self.manager.dispatch.expire(self, attribute_names) + + def _load_expired(self, state, passive): """__call__ allows the InstanceState to act as a deferred callable for loading expired attributes, which is also serializable (picklable). """ - if kw.get('passive') is attributes.PASSIVE_NO_FETCH: - return attributes.PASSIVE_NO_RESULT - + if not passive & SQL_OK: + return PASSIVE_NO_RESULT + toload = self.expired_attributes.\ - intersection(self.unmodified) - + intersection(self.unmodified) + self.manager.deferred_scalar_loader(self, toload) - # if the loader failed, or this + # if the loader failed, or this # instance state didn't have an identity, # the attributes still might be in the callables # dict. ensure they are removed. 
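# ------------------------------------------------------------------
# [editor's note] Sketch, not part of the patch: the bookkeeping done by
# _expire() / _expire_attributes() above, driven through the public
# Session API (continuing the hypothetical session/user from the first
# sketch).
from sqlalchemy import inspect

session.expire(user, ['name'])             # 'name' joins expired_attributes
assert 'name' in inspect(user).unloaded    # value removed from __dict__

user.name    # attribute access triggers _load_expired(), which emits a
             # SELECT for the expired column(s) and clears the expired set
# ------------------------------------------------------------------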
- for k in toload.intersection(self.callables): - del self.callables[k] - + self.expired_attributes.clear() + return ATTR_WAS_SET @property def unmodified(self): """Return the set of keys which have no uncommitted changes""" - + return set(self.manager).difference(self.committed_state) + def unmodified_intersection(self, keys): + """Return self.unmodified.intersection(keys).""" + + return set(keys).intersection(self.manager).\ + difference(self.committed_state) + @property def unloaded(self): """Return the set of keys which do not have a loaded value. @@ -291,54 +620,61 @@ class InstanceState(object): """ return set(self.manager).\ - difference(self.committed_state).\ - difference(self.dict) + difference(self.committed_state).\ + difference(self.dict) @property - def expired_attributes(self): - """Return the set of keys which are 'expired' to be loaded by - the manager's deferred scalar loader, assuming no pending - changes. - - see also the ``unmodified`` collection which is intersected - against this set when a refresh operation occurs. - - """ - return set([k for k, v in self.callables.items() if v is self]) + def _unloaded_non_object(self): + return self.unloaded.intersection( + attr for attr in self.manager + if self.manager[attr].impl.accepts_scalar_loader + ) def _instance_dict(self): return None - def _is_really_none(self): - return self.obj() - - def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF): - needs_committed = attr.key not in self.committed_state - - if needs_committed: - if previous is NEVER_SET: - if passive: + def _modified_event( + self, dict_, attr, previous, collection=False, force=False): + if not attr.send_modified_events: + return + if attr.key not in self.committed_state or force: + if collection: + if previous is NEVER_SET: if attr.key in dict_: previous = dict_[attr.key] - else: - previous = attr.get(self, dict_) - if should_copy and previous not in (None, NO_VALUE, NEVER_SET): - previous = attr.copy(previous) + if previous not in (None, NO_VALUE, NEVER_SET): + previous = attr.copy(previous) - if needs_committed: - self.committed_state[attr.key] = previous + self.committed_state[attr.key] = previous - if not self.modified: + # assert self._strong_obj is None or self.modified + + if (self.session_id and self._strong_obj is None) \ + or not self.modified: + self.modified = True instance_dict = self._instance_dict() if instance_dict: instance_dict._modified.add(self) - self.modified = True - if self._strong_obj is None: - self._strong_obj = self.obj() + # only create _strong_obj link if attached + # to a session - def commit(self, dict_, keys): + inst = self.obj() + if self.session_id: + self._strong_obj = inst + + if inst is None: + raise orm_exc.ObjectDereferencedError( + "Can't emit change event for attribute '%s' - " + "parent object of type %s has been garbage " + "collected." + % ( + self.manager[attr.key], + base.state_class_str(self) + )) + + def _commit(self, dict_, keys): """Commit attributes. This is used by a partial-attribute load operation to mark committed @@ -348,21 +684,25 @@ class InstanceState(object): this step if a value was not populated in state.dict. 
""" - class_manager = self.manager for key in keys: - if key in dict_ and key in class_manager.mutable_attributes: - self.committed_state[key] = self.manager[key].impl.copy(dict_[key]) - else: - self.committed_state.pop(key, None) - + self.committed_state.pop(key, None) + self.expired = False - - for key in set(self.callables).\ - intersection(keys).\ - intersection(dict_): - del self.callables[key] - - def commit_all(self, dict_, instance_dict=None): + + self.expired_attributes.difference_update( + set(keys).intersection(dict_)) + + # the per-keys commit removes object-level callables, + # while that of commit_all does not. it's not clear + # if this behavior has a clear rationale, however tests do + # ensure this is what it does. + if self.callables: + for key in set(self.callables).\ + intersection(keys).\ + intersection(dict_): + del self.callables[key] + + def _commit_all(self, dict_, instance_dict=None): """commit all attributes unconditionally. This is used after a flush() or a full load/refresh @@ -371,137 +711,115 @@ class InstanceState(object): - all attributes are marked as "committed" - the "strong dirty reference" is removed - the "modified" flag is set to False - - any "expired" markers/callables for attributes loaded are removed. + - any "expired" markers for scalar attributes loaded are removed. + - lazy load callables for objects / collections *stay* - Attributes marked as "expired" can potentially remain "expired" after this step - if a value was not populated in state.dict. + Attributes marked as "expired" can potentially remain + "expired" after this step if a value was not populated in state.dict. """ - - self.__dict__.pop('committed_state', None) - self.__dict__.pop('pending', None) + self._commit_all_states([(self, dict_)], instance_dict) - if 'callables' in self.__dict__: - callables = self.callables - for key in list(callables): - if key in dict_ and callables[key] is self: - del callables[key] + @classmethod + def _commit_all_states(self, iter, instance_dict=None): + """Mass / highly inlined version of commit_all().""" - for key in self.manager.mutable_attributes: - if key in dict_: - self.committed_state[key] = self.manager[key].impl.copy(dict_[key]) - - if instance_dict and self.modified: - instance_dict._modified.discard(self) - - self.modified = self.expired = False - self._strong_obj = None + for state, dict_ in iter: + state_dict = state.__dict__ + + state.committed_state.clear() + + if '_pending_mutations' in state_dict: + del state_dict['_pending_mutations'] + + state.expired_attributes.difference_update(dict_) + + if instance_dict and state.modified: + instance_dict._modified.discard(state) + + state.modified = state.expired = False + state._strong_obj = None + + +class AttributeState(object): + """Provide an inspection interface corresponding + to a particular attribute on a particular mapped object. + + The :class:`.AttributeState` object is accessed + via the :attr:`.InstanceState.attrs` collection + of a particular :class:`.InstanceState`:: + + from sqlalchemy import inspect + + insp = inspect(some_mapped_object) + attr_state = insp.attrs.some_attribute -class MutableAttrInstanceState(InstanceState): - """InstanceState implementation for objects that reference 'mutable' - attributes. - - Has a more involved "cleanup" handler that checks mutable attributes - for changes upon dereference, resurrecting if needed. 
- """ - - @util.memoized_property - def mutable_dict(self): - return {} - - def _get_modified(self, dict_=None): - if self.__dict__.get('modified', False): - return True - else: - if dict_ is None: - dict_ = self.dict - for key in self.manager.mutable_attributes: - if self.manager[key].impl.check_mutable_modified(self, dict_): - return True - else: - return False - - def _set_modified(self, value): - self.__dict__['modified'] = value - - modified = property(_get_modified, _set_modified) - + + def __init__(self, state, key): + self.state = state + self.key = key + @property - def unmodified(self): - """a set of keys which have no uncommitted changes""" + def loaded_value(self): + """The current value of this attribute as loaded from the database. - dict_ = self.dict - - return set([ - key for key in self.manager - if (key not in self.committed_state or - (key in self.manager.mutable_attributes and - not self.manager[key].impl.check_mutable_modified(self, dict_)))]) + If the value has not been loaded, or is otherwise not present + in the object's dictionary, returns NO_VALUE. - def _is_really_none(self): - """do a check modified/resurrect. - - This would be called in the extremely rare - race condition that the weakref returned None but - the cleanup handler had not yet established the - __resurrect callable as its replacement. - """ - if self.modified: - self.obj = self.__resurrect - return self.obj() - else: - return None + return self.state.dict.get(self.key, NO_VALUE) + + @property + def value(self): + """Return the value of this attribute. + + This operation is equivalent to accessing the object's + attribute directly or via ``getattr()``, and will fire + off any pending loader callables if needed. - def reset(self, dict_, key): - self.mutable_dict.pop(key, None) - InstanceState.reset(self, dict_, key) - - def _cleanup(self, ref): - """weakref callback. - - This method may be called by an asynchronous - gc. - - If the state shows pending changes, the weakref - is replaced by the __resurrect callable which will - re-establish an object reference on next access, - else removes this InstanceState from the owning - identity map, if any. - """ - if self._get_modified(self.mutable_dict): - self.obj = self.__resurrect - else: - instance_dict = self._instance_dict() - if instance_dict: - try: - instance_dict.remove(self) - except AssertionError: - pass - self.dispose() - - def __resurrect(self): - """A substitute for the obj() weakref function which resurrects.""" - - # store strong ref'ed version of the object; will revert - # to weakref when changes are persisted - - obj = self.manager.new_instance(state=self) - self.obj = weakref.ref(obj, self._cleanup) - self._strong_obj = obj - obj.__dict__.update(self.mutable_dict) + return self.state.manager[self.key].__get__( + self.state.obj(), self.state.class_) + + @property + def history(self): + """Return the current pre-flush change history for + this attribute, via the :class:`.History` interface. + + This method will **not** emit loader callables if the value of the + attribute is unloaded. + + .. seealso:: + + :meth:`.AttributeState.load_history` - retrieve history + using loader callables if the value is not locally present. + + :func:`.attributes.get_history` - underlying function + + """ + return self.state.get_history(self.key, + PASSIVE_NO_INITIALIZE) + + def load_history(self): + """Return the current pre-flush change history for + this attribute, via the :class:`.History` interface. 
+ + This method **will** emit loader callables if the value of the + attribute is unloaded. + + .. seealso:: + + :attr:`.AttributeState.history` + + :func:`.attributes.get_history` - underlying function + + .. versionadded:: 0.9.0 + + """ + return self.state.get_history(self.key, + PASSIVE_OFF ^ INIT_OK) - # re-establishes identity attributes from the key - self.manager.events.run('on_resurrect', self, obj) - - # TODO: don't really think we should run this here. - # resurrect is only meant to preserve the minimal state needed to - # do an UPDATE, not to produce a fully usable object - self._run_on_load(obj) - - return obj class PendingCollection(object): """A writable placeholder for an unloaded collection. @@ -511,6 +829,7 @@ class PendingCollection(object): PendingCollection are applied to it to produce the final result. """ + def __init__(self): self.deleted_items = util.IdentitySet() self.added_items = util.OrderedIdentitySet() @@ -518,10 +837,11 @@ class PendingCollection(object): def append(self, value): if value in self.deleted_items: self.deleted_items.remove(value) - self.added_items.add(value) + else: + self.added_items.add(value) def remove(self, value): if value in self.added_items: self.added_items.remove(value) - self.deleted_items.add(value) - + else: + self.deleted_items.add(value) diff --git a/sqlalchemy/orm/strategies.py b/sqlalchemy/orm/strategies.py index 25c2f83..c70994e 100644 --- a/sqlalchemy/orm/strategies.py +++ b/sqlalchemy/orm/strategies.py @@ -1,622 +1,639 @@ -# strategies.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# orm/strategies.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""sqlalchemy.orm.interfaces.LoaderStrategy +"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions.""" -from sqlalchemy import exc as sa_exc -from sqlalchemy import sql, util, log -from sqlalchemy.sql import util as sql_util -from sqlalchemy.sql import visitors, expression, operators -from sqlalchemy.orm import mapper, attributes, interfaces, exc as orm_exc -from sqlalchemy.orm.interfaces import ( - LoaderStrategy, StrategizedOption, MapperOption, PropertyOption, - serialize_path, deserialize_path, StrategizedProperty - ) -from sqlalchemy.orm import session as sessionlib -from sqlalchemy.orm import util as mapperutil +from .. import exc as sa_exc, inspect +from .. import util, log, event +from ..sql import util as sql_util, visitors +from .. import sql +from . import ( + attributes, interfaces, exc as orm_exc, loading, + unitofwork, util as orm_util +) +from .state import InstanceState +from .util import _none_set +from . 
import properties +from .interfaces import ( + LoaderStrategy, StrategizedProperty +) +from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE +from .session import _state_session import itertools -def _register_attribute(strategy, mapper, useobject, - compare_function=None, - typecallable=None, - copy_function=None, - mutable_scalars=False, - uselist=False, - callable_=None, - proxy_property=None, - active_history=False, - impl_class=None, - **kw + +def _register_attribute( + prop, mapper, useobject, + compare_function=None, + typecallable=None, + callable_=None, + proxy_property=None, + active_history=False, + impl_class=None, + **kw ): - prop = strategy.parent_property attribute_ext = list(util.to_list(prop.extension, default=[])) - + + listen_hooks = [] + + uselist = useobject and prop.uselist + if useobject and prop.single_parent: - attribute_ext.insert(0, _SingleParentValidator(prop)) + listen_hooks.append(single_parent_validator) - if prop.key in prop.parent._validators: - attribute_ext.insert(0, - mapperutil.Validator(prop.key, prop.parent._validators[prop.key]) + if prop.key in prop.parent.validators: + fn, opts = prop.parent.validators[prop.key] + listen_hooks.append( + lambda desc, prop: orm_util._validator_events( + desc, + prop.key, fn, **opts) ) - - if useobject: - attribute_ext.append(sessionlib.UOWEventHandler(prop.key)) - - for m in mapper.polymorphic_iterator(): - if prop is m._props.get(prop.key): - - attributes.register_attribute_impl( - m.class_, - prop.key, + if useobject: + listen_hooks.append(unitofwork.track_cascade_events) + + # need to assemble backref listeners + # after the singleparentvalidator, mapper validator + if useobject: + backref = prop.back_populates + if backref: + listen_hooks.append( + lambda desc, prop: attributes.backref_listeners( + desc, + backref, + uselist + ) + ) + + # a single MapperProperty is shared down a class inheritance + # hierarchy, so we set up attribute instrumentation and backref event + # for each mapper down the hierarchy. + + # typically, "mapper" is the same as prop.parent, due to the way + # the configure_mappers() process runs, however this is not strongly + # enforced, and in the case of a second configure_mappers() run the + # mapper here might not be prop.parent; also, a subclass mapper may + # be called here before a superclass mapper. That is, can't depend + # on mappers not already being set up so we have to check each one. + + for m in mapper.self_and_descendants: + if prop is m._props.get(prop.key) and \ + not m.class_manager._attr_has_impl(prop.key): + + desc = attributes.register_attribute_impl( + m.class_, + prop.key, parent_token=prop, - mutable_scalars=mutable_scalars, - uselist=uselist, - copy_function=copy_function, - compare_function=compare_function, - useobject=useobject, - extension=attribute_ext, - trackparent=useobject, + uselist=uselist, + compare_function=compare_function, + useobject=useobject, + extension=attribute_ext, + trackparent=useobject and ( + prop.single_parent or + prop.direction is interfaces.ONETOMANY), typecallable=typecallable, - callable_=callable_, + callable_=callable_, active_history=active_history, impl_class=impl_class, + send_modified_events=not useobject or not prop.viewonly, + doc=prop.doc, **kw - ) + ) + for hook in listen_hooks: + hook(desc, prop) + + +@properties.ColumnProperty.strategy_for(instrument=False, deferred=False) class UninstrumentedColumnLoader(LoaderStrategy): - """Represent the a non-instrumented MapperProperty. 
- + """Represent a non-instrumented MapperProperty. + The polymorphic_on argument of mapper() often results in this, if the argument is against the with_polymorphic selectable. - + """ - def init(self): + __slots__ = 'columns', + + def __init__(self, parent, strategy_key): + super(UninstrumentedColumnLoader, self).__init__(parent, strategy_key) self.columns = self.parent_property.columns - def setup_query(self, context, entity, path, adapter, - column_collection=None, **kwargs): + def setup_query( + self, context, entity, path, loadopt, adapter, + column_collection=None, **kwargs): for c in self.columns: if adapter: c = adapter.columns[c] column_collection.append(c) - def create_row_processor(self, selectcontext, path, mapper, row, adapter): - return None, None + def create_row_processor( + self, context, path, loadopt, + mapper, result, adapter, populators): + pass + +@log.class_logger +@properties.ColumnProperty.strategy_for(instrument=True, deferred=False) class ColumnLoader(LoaderStrategy): - """Strategize the loading of a plain column-based MapperProperty.""" - - def init(self): + """Provide loading behavior for a :class:`.ColumnProperty`.""" + + __slots__ = 'columns', 'is_composite' + + def __init__(self, parent, strategy_key): + super(ColumnLoader, self).__init__(parent, strategy_key) self.columns = self.parent_property.columns self.is_composite = hasattr(self.parent_property, 'composite_class') - - def setup_query(self, context, entity, path, adapter, - column_collection=None, **kwargs): + + def setup_query( + self, context, entity, path, loadopt, + adapter, column_collection, memoized_populators, **kwargs): + for c in self.columns: if adapter: c = adapter.columns[c] column_collection.append(c) - + + fetch = self.columns[0] + if adapter: + fetch = adapter.columns[fetch] + memoized_populators[self.parent_property] = fetch + def init_class_attribute(self, mapper): self.is_class_level = True coltype = self.columns[0].type # TODO: check all columns ? check for foreign key as well? 
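# ------------------------------------------------------------------
# [editor's note] Sketch, not part of the patch: attribute history, which
# the active_history flag computed above feeds.  With active_history=True
# (primary key, version id column, or set explicitly) the prior value is
# loaded *before* being replaced; for an already-loaded value, as here,
# it lands in History.deleted either way.  Continues the hypothetical
# user object from the earlier sketches.
from sqlalchemy.orm.attributes import get_history

user.name = 'fred'
hist = get_history(user, 'name')
print(hist.added, hist.deleted)    # e.g. ['fred'] ['ed']
# ------------------------------------------------------------------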
- active_history = self.columns[0].primary_key + active_history = self.parent_property.active_history or \ + self.columns[0].primary_key or \ + mapper.version_id_col in set(self.columns) - _register_attribute(self, mapper, useobject=False, + _register_attribute( + self.parent_property, mapper, useobject=False, compare_function=coltype.compare_values, - copy_function=coltype.copy_value, - mutable_scalars=self.columns[0].type.is_mutable(), - active_history = active_history - ) - - def create_row_processor(self, selectcontext, path, mapper, row, adapter): - key, col = self.key, self.columns[0] - if adapter: - col = adapter.columns[col] - - if col is not None and col in row: - def new_execute(state, dict_, row): - dict_[key] = row[col] - else: - def new_execute(state, dict_, row): - state.expire_attribute_pre_commit(dict_, key) - return new_execute, None - -log.class_logger(ColumnLoader) - -class CompositeColumnLoader(ColumnLoader): - """Strategize the loading of a composite column-based MapperProperty.""" - - def init_class_attribute(self, mapper): - self.is_class_level = True - self.logger.info("%s register managed composite attribute", self) - - def copy(obj): - if obj is None: - return None - return self.parent_property.\ - composite_class(*obj.__composite_values__()) - - def compare(a, b): - if a is None or b is None: - return a is b - - for col, aprop, bprop in zip(self.columns, - a.__composite_values__(), - b.__composite_values__()): - if not col.type.compare_values(aprop, bprop): - return False - else: - return True - - _register_attribute(self, mapper, useobject=False, - compare_function=compare, - copy_function=copy, - mutable_scalars=True - #active_history ? + active_history=active_history ) - def create_row_processor(self, selectcontext, path, mapper, - row, adapter): - key = self.key - columns = self.columns - composite_class = self.parent_property.composite_class - if adapter: - columns = [adapter.columns[c] for c in columns] - - for c in columns: - if c not in row: - def new_execute(state, dict_, row): - state.expire_attribute_pre_commit(dict_, key) + def create_row_processor( + self, context, path, + loadopt, mapper, result, adapter, populators): + # look through list of columns represented here + # to see which, if any, is present in the row. 
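# ------------------------------------------------------------------
# [editor's note] Simplified sketch of an internal contract, not part of
# the patch and not public API: the "populators" mapping that the
# create_row_processor() methods in this hunk fill in.  The category
# names are taken from the code above; the loading layer replays each
# category once per row.
populators = {
    "quick": [],     # (key, getter): pull the value straight off the row
    "expire": [],    # (key, flag): mark the attribute expired instead
    "new": [],       # (key, fn(state, dict_, row)): run for new instances
    "delayed": [],   # (key, fn): like "new", but run after the row loads
}


def apply_quick(populators, dict_, row):
    # roughly what the ORM loading layer does with the "quick" category
    for key, getter in populators["quick"]:
        dict_[key] = getter(row)
# ------------------------------------------------------------------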
+ for col in self.columns: + if adapter: + col = adapter.columns[col] + getter = result._getter(col, False) + if getter: + populators["quick"].append((self.key, getter)) break else: - def new_execute(state, dict_, row): - dict_[key] = composite_class(*[row[c] for c in columns]) + populators["expire"].append((self.key, True)) - return new_execute, None -log.class_logger(CompositeColumnLoader) - +@log.class_logger +@properties.ColumnProperty.strategy_for(deferred=True, instrument=True) class DeferredColumnLoader(LoaderStrategy): - """Strategize the loading of a deferred column-based MapperProperty.""" + """Provide loading behavior for a deferred :class:`.ColumnProperty`.""" - def create_row_processor(self, selectcontext, path, mapper, row, adapter): - col = self.columns[0] - if adapter: - col = adapter.columns[col] + __slots__ = 'columns', 'group' - key = self.key - if col in row: - return self.parent_property._get_strategy(ColumnLoader).\ - create_row_processor( - selectcontext, path, mapper, row, adapter) - - elif not self.is_class_level: - def new_execute(state, dict_, row): - state.set_callable(dict_, key, LoadDeferredColumns(state, key)) - else: - def new_execute(state, dict_, row): - # reset state on the key so that deferred callables - # fire off on next access. - state.reset(dict_, key) - - return new_execute, None - - def init(self): + def __init__(self, parent, strategy_key): + super(DeferredColumnLoader, self).__init__(parent, strategy_key) if hasattr(self.parent_property, 'composite_class'): raise NotImplementedError("Deferred loading for composite " - "types not implemented yet") + "types not implemented yet") self.columns = self.parent_property.columns self.group = self.parent_property.group + def create_row_processor( + self, context, path, loadopt, + mapper, result, adapter, populators): + + # this path currently does not check the result + # for the column; this is because in most cases we are + # working just with the setup_query() directive which does + # not support this, and the behavior here should be consistent. 
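# ------------------------------------------------------------------
# [editor's note] Sketch, not part of the patch: the user-facing behavior
# DeferredColumnLoader implements, on a hypothetical "Book" model that
# reuses Base/engine/session from the first sketch.
from sqlalchemy import Column, Integer, Text
from sqlalchemy.orm import deferred, undefer


class Book(Base):
    __tablename__ = 'book'
    id = Column(Integer, primary_key=True)
    summary = deferred(Column(Text))


Base.metadata.create_all(engine)
session.add(Book(summary='a long text'))
session.flush()

book = session.query(Book).first()    # SELECT omits "summary"
book.summary                          # access emits a second SELECT

# or undefer it so the first query loads everything at once:
book = session.query(Book).options(undefer('summary')).first()
# ------------------------------------------------------------------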
+ if not self.is_class_level: + set_deferred_for_local_state = \ + self.parent_property._deferred_column_loader + populators["new"].append((self.key, set_deferred_for_local_state)) + else: + populators["expire"].append((self.key, False)) + def init_class_attribute(self, mapper): self.is_class_level = True - - _register_attribute(self, mapper, useobject=False, - compare_function=self.columns[0].type.compare_values, - copy_function=self.columns[0].type.copy_value, - mutable_scalars=self.columns[0].type.is_mutable(), - callable_=self._class_level_loader, - expire_missing=False + + _register_attribute( + self.parent_property, mapper, useobject=False, + compare_function=self.columns[0].type.compare_values, + callable_=self._load_for_state, + expire_missing=False ) - def setup_query(self, context, entity, path, adapter, - only_load_props=None, **kwargs): + def setup_query( + self, context, entity, path, loadopt, + adapter, column_collection, memoized_populators, + only_load_props=None, **kw): + if ( - self.group is not None and - context.attributes.get(('undefer', self.group), False) - ) or (only_load_props and self.key in only_load_props): - self.parent_property._get_strategy(ColumnLoader).\ - setup_query(context, entity, - path, adapter, **kwargs) - - def _class_level_loader(self, state): - if not mapperutil._state_has_identity(state): - return None - - return LoadDeferredColumns(state, self.key) - - -log.class_logger(DeferredColumnLoader) + ( + loadopt and + 'undefer_pks' in loadopt.local_opts and + set(self.columns).intersection( + self.parent._should_undefer_in_wildcard) + ) + or + ( + loadopt and + self.group and + loadopt.local_opts.get('undefer_group_%s' % self.group, False) + ) + or + ( + only_load_props and self.key in only_load_props + ) + ): + self.parent_property._get_strategy( + (("deferred", False), ("instrument", True)) + ).setup_query( + context, entity, + path, loadopt, adapter, + column_collection, memoized_populators, **kw) + elif self.is_class_level: + memoized_populators[self.parent_property] = _SET_DEFERRED_EXPIRED + else: + memoized_populators[self.parent_property] = _DEFER_FOR_STATE -class LoadDeferredColumns(object): - """serializable loader object used by DeferredColumnLoader""" - - def __init__(self, state, key): - self.state, self.key = state, key + def _load_for_state(self, state, passive): + if not state.key: + return attributes.ATTR_EMPTY - def __call__(self, **kw): - if kw.get('passive') is attributes.PASSIVE_NO_FETCH: + if not passive & attributes.SQL_OK: return attributes.PASSIVE_NO_RESULT - state = self.state - - localparent = mapper._state_mapper(state) - - prop = localparent.get_property(self.key) - strategy = prop._get_strategy(DeferredColumnLoader) + localparent = state.manager.mapper - if strategy.group: + if self.group: toload = [ - p.key for p in - localparent.iterate_properties - if isinstance(p, StrategizedProperty) and - isinstance(p.strategy, DeferredColumnLoader) and - p.group==strategy.group - ] + p.key for p in + localparent.iterate_properties + if isinstance(p, StrategizedProperty) and + isinstance(p.strategy, DeferredColumnLoader) and + p.group == self.group + ] else: toload = [self.key] # narrow the keys down to just those which have no history group = [k for k in toload if k in state.unmodified] - if strategy._should_log_debug(): - strategy.logger.debug( - "deferred load %s group %s", - (mapperutil.state_attribute_str(state, self.key), - group and ','.join(group) or 'None') - ) - - session = sessionlib._state_session(state) + session = 
_state_session(state)
        if session is None:
            raise orm_exc.DetachedInstanceError(
                "Parent instance %s is not bound to a Session; "
-                "deferred load operation of attribute '%s' cannot proceed" %
-                (mapperutil.state_str(state), self.key)
-            )
+                "deferred load operation of attribute '%s' cannot proceed" %
+                (orm_util.state_str(state), self.key)
+            )

        query = session.query(localparent)
-        ident = state.key[1]
-        query._get(None, ident=ident,
-                   only_load_props=group, refresh_state=state)
+        if loading.load_on_ident(
+                query, state.key,
+                only_load_props=group, refresh_state=state) is None:
+            raise orm_exc.ObjectDeletedError(state)
+
        return attributes.ATTR_WAS_SET

-class DeferredOption(StrategizedOption):
-    propagate_to_loaders = True
-
-    def __init__(self, key, defer=False):
-        super(DeferredOption, self).__init__(key)
-        self.defer = defer

-    def get_strategy_class(self):
-        if self.defer:
-            return DeferredColumnLoader
-        else:
-            return ColumnLoader
+class LoadDeferredColumns(object):
+    """serializable loader object used by DeferredColumnLoader"""

-class UndeferGroupOption(MapperOption):
-    propagate_to_loaders = True
+    def __init__(self, key):
+        self.key = key
+
+    def __call__(self, state, passive=attributes.PASSIVE_OFF):
+        key = self.key
+
+        localparent = state.manager.mapper
+        prop = localparent._props[key]
+        strategy = prop._strategies[DeferredColumnLoader]
+        return strategy._load_for_state(state, passive)

-    def __init__(self, group):
-        self.group = group
-
-    def process_query(self, query):
-        query._attributes[('undefer', self.group)] = True

class AbstractRelationshipLoader(LoaderStrategy):
    """LoaderStrategies which deal with related objects."""

-    def init(self):
+    __slots__ = 'mapper', 'target', 'uselist'
+
+    def __init__(self, parent, strategy_key):
+        super(AbstractRelationshipLoader, self).__init__(parent, strategy_key)
        self.mapper = self.parent_property.mapper
        self.target = self.parent_property.target
-        self.table = self.parent_property.table
        self.uselist = self.parent_property.uselist

+
+@log.class_logger
+@properties.RelationshipProperty.strategy_for(lazy="noload")
+@properties.RelationshipProperty.strategy_for(lazy=None)
class NoLoader(AbstractRelationshipLoader):
-    """Strategize a relationship() that doesn't load data automatically."""
+    """Provide loading behavior for a :class:`.RelationshipProperty`
+    with "lazy=None".
+ + """ + + __slots__ = () def init_class_attribute(self, mapper): self.is_class_level = True - _register_attribute(self, mapper, - useobject=True, - uselist=self.parent_property.uselist, - typecallable = self.parent_property.collection_class, + _register_attribute( + self.parent_property, mapper, + useobject=True, + typecallable=self.parent_property.collection_class, ) - def create_row_processor(self, selectcontext, path, mapper, row, adapter): - def new_execute(state, dict_, row): - state.initialize(self.key) - return new_execute, None + def create_row_processor( + self, context, path, loadopt, mapper, + result, adapter, populators): + def invoke_no_load(state, dict_, row): + if self.uselist: + state.manager.get_impl(self.key).initialize(state, dict_) + else: + dict_[self.key] = None + populators["new"].append((self.key, invoke_no_load)) -log.class_logger(NoLoader) - -class LazyLoader(AbstractRelationshipLoader): - """Strategize a relationship() that loads when first accessed.""" - def init(self): - super(LazyLoader, self).init() - self.__lazywhere, \ - self.__bind_to_col, \ - self._equated_columns = self._create_lazy_clause(self.parent_property) - - self.logger.info("%s lazy loading clause %s", self, self.__lazywhere) +@log.class_logger +@properties.RelationshipProperty.strategy_for(lazy=True) +@properties.RelationshipProperty.strategy_for(lazy="select") +@properties.RelationshipProperty.strategy_for(lazy="raise") +@properties.RelationshipProperty.strategy_for(lazy="raise_on_sql") +class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): + """Provide loading behavior for a :class:`.RelationshipProperty` + with "lazy=True", that is loads when first accessed. + + """ + + __slots__ = ( + '_lazywhere', '_rev_lazywhere', 'use_get', '_bind_to_col', + '_equated_columns', '_rev_bind_to_col', '_rev_equated_columns', + '_simple_lazy_clause', '_raise_always', '_raise_on_sql') + + def __init__(self, parent, strategy_key): + super(LazyLoader, self).__init__(parent, strategy_key) + self._raise_always = self.strategy_opts["lazy"] == "raise" + self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql" + + join_condition = self.parent_property._join_condition + self._lazywhere, \ + self._bind_to_col, \ + self._equated_columns = join_condition.create_lazy_clause() + + self._rev_lazywhere, \ + self._rev_bind_to_col, \ + self._rev_equated_columns = join_condition.create_lazy_clause( + reverse_direction=True) + + self.logger.info("%s lazy loading clause %s", self, self._lazywhere) # determine if our "lazywhere" clause is the same as the mapper's # get() clause. 
then we can just use mapper.get() - #from sqlalchemy.orm import query self.use_get = not self.uselist and \ - self.mapper._get_clause[0].compare( - self.__lazywhere, - use_proxies=True, - equivalents=self.mapper._equivalent_columns - ) - + self.mapper._get_clause[0].compare( + self._lazywhere, + use_proxies=True, + equivalents=self.mapper._equivalent_columns + ) + if self.use_get: - for col in self._equated_columns.keys(): + for col in list(self._equated_columns): if col in self.mapper._equivalent_columns: for c in self.mapper._equivalent_columns[col]: self._equated_columns[c] = self._equated_columns[col] - + self.logger.info("%s will use query.get() to " - "optimize instance loads" % self) + "optimize instance loads", self) def init_class_attribute(self, mapper): self.is_class_level = True - - # MANYTOONE currently only needs the + + active_history = ( + self.parent_property.active_history or + self.parent_property.direction is not interfaces.MANYTOONE or + not self.use_get + ) + + # MANYTOONE currently only needs the # "old" value for delete-orphan - # cascades. the required _SingleParentValidator + # cascades. the required _SingleParentValidator # will enable active_history - # in that case. otherwise we don't need the + # in that case. otherwise we don't need the # "old" value during backref operations. - _register_attribute(self, - mapper, - useobject=True, - callable_=self._class_level_loader, - uselist = self.parent_property.uselist, - typecallable = self.parent_property.collection_class, - active_history = \ - self.parent_property.direction is not \ - interfaces.MANYTOONE or \ - not self.use_get, - ) + _register_attribute( + self.parent_property, + mapper, + useobject=True, + callable_=self._load_for_state, + typecallable=self.parent_property.collection_class, + active_history=active_history + ) - def lazy_clause(self, state, reverse_direction=False, - alias_secondary=False, adapt_source=None): - if state is None: - return self._lazy_none_clause( - reverse_direction, - adapt_source=adapt_source) - - if not reverse_direction: - criterion, bind_to_col, rev = \ - self.__lazywhere, \ - self.__bind_to_col, \ - self._equated_columns - else: - criterion, bind_to_col, rev = \ - LazyLoader._create_lazy_clause( - self.parent_property, - reverse_direction=reverse_direction) + def _memoized_attr__simple_lazy_clause(self): + criterion, bind_to_col = ( + self._lazywhere, + self._bind_to_col + ) - if reverse_direction: - mapper = self.parent_property.mapper - else: - mapper = self.parent_property.parent + params = [] def visit_bindparam(bindparam): - if bindparam.key in bind_to_col: - # use the "committed" (database) version to get - # query column values - # also its a deferred value; so that when used - # by Query, the committed value is used - # after an autoflush occurs - o = state.obj() # strong ref - bindparam.value = \ - lambda: mapper._get_committed_attr_by_column( - o, bind_to_col[bindparam.key]) - - if self.parent_property.secondary is not None and alias_secondary: - criterion = sql_util.ClauseAdapter( - self.parent_property.secondary.alias()).\ - traverse(criterion) + bindparam.unique = False + if bindparam._identifying_key in bind_to_col: + params.append(( + bindparam.key, bind_to_col[bindparam._identifying_key], + None)) + elif bindparam.callable is None: + params.append((bindparam.key, None, bindparam.value)) criterion = visitors.cloned_traverse( - criterion, {}, {'bindparam':visit_bindparam}) - if adapt_source: - criterion = adapt_source(criterion) - return criterion - - def 
_lazy_none_clause(self, reverse_direction=False, adapt_source=None): - if not reverse_direction: - criterion, bind_to_col, rev = \ - self.__lazywhere, \ - self.__bind_to_col,\ - self._equated_columns - else: - criterion, bind_to_col, rev = \ - LazyLoader._create_lazy_clause( - self.parent_property, - reverse_direction=reverse_direction) + criterion, {}, {'bindparam': visit_bindparam} + ) - criterion = sql_util.adapt_criterion_to_null(criterion, bind_to_col) + return criterion, params - if adapt_source: - criterion = adapt_source(criterion) - return criterion - - def _class_level_loader(self, state): - if not mapperutil._state_has_identity(state): - return None + def _generate_lazy_clause(self, state, passive): + criterion, param_keys = self._simple_lazy_clause - return LoadLazyAttribute(state, self.key) + if state is None: + return sql_util.adapt_criterion_to_null( + criterion, [key for key, ident, value in param_keys]) - def create_row_processor(self, selectcontext, path, mapper, row, adapter): - key = self.key - if not self.is_class_level: - def new_execute(state, dict_, row): - # we are not the primary manager for this attribute - # on this class - set up a - # per-instance lazyloader, which will override the - # class-level behavior. - # this currently only happens when using a - # "lazyload" option on a "no load" - # attribute - "eager" attributes always have a - # class-level lazyloader installed. - state.set_callable(dict_, key, LoadLazyAttribute(state, key)) - else: - def new_execute(state, dict_, row): - # we are the primary manager for this attribute on - # this class - reset its - # per-instance attribute state, so that the class-level - # lazy loader is - # executed when next referenced on this instance. - # this is needed in - # populate_existing() types of scenarios to reset - # any existing state. 
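# ------------------------------------------------------------------
# [editor's note] Sketch, not part of the patch: the shape of what
# _generate_lazy_clause() above produces -- the relationship's primaryjoin
# with the parent's side replaced by bind parameters that are filled from
# the parent row.  The one-to-many "Address" model is hypothetical and
# reuses Base/engine/session/user from the earlier sketches.
from sqlalchemy import Column, ForeignKey, Integer, bindparam
from sqlalchemy.orm import relationship


class Address(Base):
    __tablename__ = 'address'
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey('user.id'))


User.addresses = relationship(Address)
Base.metadata.create_all(engine)
session.add(Address(user_id=user.id))
session.flush()

# roughly the criterion and params the lazy loader runs for user.addresses:
lazy_clause = Address.__table__.c.user_id == bindparam('param_1')
print(session.query(Address)
      .filter(lazy_clause).params(param_1=user.id).all())
# ------------------------------------------------------------------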
- state.reset(dict_, key) + mapper = self.parent_property.parent - return new_execute, None - - @classmethod - def _create_lazy_clause(cls, prop, reverse_direction=False): - binds = util.column_dict() - lookup = util.column_dict() - equated_columns = util.column_dict() + o = state.obj() # strong ref + dict_ = attributes.instance_dict(o) - if reverse_direction and prop.secondaryjoin is None: - for l, r in prop.local_remote_pairs: - _list = lookup.setdefault(r, []) - _list.append((r, l)) - equated_columns[l] = r - else: - for l, r in prop.local_remote_pairs: - _list = lookup.setdefault(l, []) - _list.append((l, r)) - equated_columns[r] = l - - def col_to_bind(col): - if col in lookup: - for tobind, equated in lookup[col]: - if equated in binds: - return None - if col not in binds: - binds[col] = sql.bindparam(None, None, type_=col.type) - return binds[col] - return None - - lazywhere = prop.primaryjoin + if passive & attributes.INIT_OK: + passive ^= attributes.INIT_OK - if prop.secondaryjoin is None or not reverse_direction: - lazywhere = visitors.replacement_traverse( - lazywhere, {}, col_to_bind) - - if prop.secondaryjoin is not None: - secondaryjoin = prop.secondaryjoin - if reverse_direction: - secondaryjoin = visitors.replacement_traverse( - secondaryjoin, {}, col_to_bind) - lazywhere = sql.and_(lazywhere, secondaryjoin) - - bind_to_col = dict((binds[col].key, col) for col in binds) - - return lazywhere, bind_to_col, equated_columns - -log.class_logger(LazyLoader) + params = {} + for key, ident, value in param_keys: + if ident is not None: + if passive and passive & attributes.LOAD_AGAINST_COMMITTED: + value = mapper._get_committed_state_attr_by_column( + state, dict_, ident, passive) + else: + value = mapper._get_state_attr_by_column( + state, dict_, ident, passive) -class LoadLazyAttribute(object): - """serializable loader object used by LazyLoader""" + params[key] = value - def __init__(self, state, key): - self.state, self.key = state, key - - def __getstate__(self): - return (self.state, self.key) + return criterion, params - def __setstate__(self, state): - self.state, self.key = state - - def __call__(self, **kw): - state = self.state - instance_mapper = mapper._state_mapper(state) - prop = instance_mapper.get_property(self.key) - strategy = prop._get_strategy(LazyLoader) + def _invoke_raise_load(self, state, passive, lazy): + raise sa_exc.InvalidRequestError( + "'%s' is not available due to lazy='%s'" % (self, lazy) + ) - if kw.get('passive') is attributes.PASSIVE_NO_FETCH and \ - not strategy.use_get: + def _load_for_state(self, state, passive): + + if not state.key and ( + ( + not self.parent_property.load_on_pending + and not state._load_pending + ) + or not state.session_id + ): + return attributes.ATTR_EMPTY + + pending = not state.key + ident_key = None + + if ( + (not passive & attributes.SQL_OK and not self.use_get) + or + (not passive & attributes.NON_PERSISTENT_OK and pending) + ): return attributes.PASSIVE_NO_RESULT - if strategy._should_log_debug(): - strategy.logger.debug("loading %s", - mapperutil.state_attribute_str( - state, self.key)) - - session = sessionlib._state_session(state) - if session is None: + if self._raise_always: + self._invoke_raise_load(state, passive, "raise") + + session = _state_session(state) + if not session: raise orm_exc.DetachedInstanceError( "Parent instance %s is not bound to a Session; " - "lazy load operation of attribute '%s' cannot proceed" % - (mapperutil.state_str(state), self.key) + "lazy load operation of attribute '%s' cannot 
proceed" % + (orm_util.state_str(state), self.key) ) - - q = session.query(prop.mapper)._adapt_all_clauses() - - if state.load_path: - q = q._with_current_path(state.load_path + (self.key,)) - - # if we have a simple primary key load, use mapper.get() - # to possibly save a DB round trip - if strategy.use_get: - ident = [] - allnulls = True - for primary_key in prop.mapper.primary_key: - val = instance_mapper.\ - _get_committed_state_attr_by_column( - state, - strategy._equated_columns[primary_key], - **kw) - if val is attributes.PASSIVE_NO_RESULT: - return val - allnulls = allnulls and val is None - ident.append(val) - - if allnulls: + + # if we have a simple primary key load, check the + # identity map without generating a Query at all + if self.use_get: + ident = self._get_ident_for_use_get( + session, + state, + passive + ) + if attributes.PASSIVE_NO_RESULT in ident: + return attributes.PASSIVE_NO_RESULT + elif attributes.NEVER_SET in ident: + return attributes.NEVER_SET + + if _none_set.issuperset(ident): return None - - if state.load_options: - q = q._conditional_options(*state.load_options) - key = prop.mapper.identity_key_from_primary_key(ident) - return q._get(key, ident, **kw) - + ident_key = self.mapper.identity_key_from_primary_key(ident) + instance = loading.get_from_identity(session, ident_key, passive) + if instance is not None: + return instance + elif not passive & attributes.SQL_OK or \ + not passive & attributes.RELATED_OBJECT_OK: + return attributes.PASSIVE_NO_RESULT - if prop.order_by: - q = q.order_by(*util.to_list(prop.order_by)) + return self._emit_lazyload(session, state, ident_key, passive) + + def _get_ident_for_use_get(self, session, state, passive): + instance_mapper = state.manager.mapper + + if passive & attributes.LOAD_AGAINST_COMMITTED: + get_attr = instance_mapper._get_committed_state_attr_by_column + else: + get_attr = instance_mapper._get_state_attr_by_column + + dict_ = state.dict + + return [ + get_attr( + state, + dict_, + self._equated_columns[pk], + passive=passive) + for pk in self.mapper.primary_key + ] + + @util.dependencies("sqlalchemy.orm.strategy_options") + def _emit_lazyload( + self, strategy_options, session, state, ident_key, passive): + + q = session.query(self.mapper)._adapt_all_clauses() + if self.parent_property.secondary is not None: + q = q.select_from(self.mapper, self.parent_property.secondary) + + q = q._with_invoke_all_eagers(False) + + pending = not state.key + + # don't autoflush on pending + if pending or passive & attributes.NO_AUTOFLUSH: + q = q.autoflush(False) + + if state.load_path: + q = q._with_current_path(state.load_path[self.parent_property]) if state.load_options: q = q._conditional_options(*state.load_options) - q = q.filter(strategy.lazy_clause(state)) + + if self.use_get: + if self._raise_on_sql: + self._invoke_raise_load(state, passive, "raise_on_sql") + return loading.load_on_ident(q, ident_key) + + if self.parent_property.order_by: + q = q.order_by(*util.to_list(self.parent_property.order_by)) + + for rev in self.parent_property._reverse_property: + # reverse props that are MANYTOONE are loading *this* + # object from get(), so don't need to eager out to those. 
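# ------------------------------------------------------------------
# [editor's note] Sketch, not part of the patch: the "use_get" fast path
# above.  A simple many-to-one lazy load checks the Session identity map
# first and only falls back to SQL when the target isn't already present
# (hypothetical many-to-one added to the Address model from the previous
# sketch).
from sqlalchemy.orm import relationship

Address.user = relationship(User)

parent = session.query(User).get(1)     # User now sits in the identity map
addr = session.query(Address).first()
assert addr.user is parent              # resolved via get(): no extra SELECT
# ------------------------------------------------------------------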
+ if rev.direction is interfaces.MANYTOONE and \ + rev._use_get and \ + not isinstance(rev.strategy, LazyLoader): + q = q.options( + strategy_options.Load.for_existing_path( + q._current_path[rev.parent] + ).lazyload(rev.key) + ) + + lazy_clause, params = self._generate_lazy_clause( + state, passive=passive) + + if pending: + if util.has_intersection( + orm_util._none_set, params.values()): + return None + elif util.has_intersection(orm_util._never_set, params.values()): + return None + + if self._raise_on_sql: + self._invoke_raise_load(state, passive, "raise_on_sql") + + q = q.filter(lazy_clause).params(params) result = q.all() - if strategy.uselist: + if self.uselist: return result else: l = len(result) @@ -624,75 +641,240 @@ class LoadLazyAttribute(object): if l > 1: util.warn( "Multiple rows returned with " - "uselist=False for lazily-loaded attribute '%s' " - % prop) - + "uselist=False for lazily-loaded attribute '%s' " + % self.parent_property) + return result[0] else: return None -class SubqueryLoader(AbstractRelationshipLoader): - def init(self): - super(SubqueryLoader, self).init() - self.join_depth = self.parent_property.join_depth - + def create_row_processor( + self, context, path, loadopt, + mapper, result, adapter, populators): + key = self.key + if not self.is_class_level: + # we are not the primary manager for this attribute + # on this class - set up a + # per-instance lazyloader, which will override the + # class-level behavior. + # this currently only happens when using a + # "lazyload" option on a "no load" + # attribute - "eager" attributes always have a + # class-level lazyloader installed. + set_lazy_callable = InstanceState._instance_level_callable_processor( + mapper.class_manager, + LoadLazyAttribute(key, self), key) + + populators["new"].append((self.key, set_lazy_callable)) + elif context.populate_existing or mapper.always_refresh: + def reset_for_lazy_callable(state, dict_, row): + # we are the primary manager for this attribute on + # this class - reset its + # per-instance attribute state, so that the class-level + # lazy loader is + # executed when next referenced on this instance. + # this is needed in + # populate_existing() types of scenarios to reset + # any existing state. 
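# The two row-processor behaviors wired up here, with plain dicts in
# place of InstanceState: a non-class-level strategy installs a
# per-instance loader fired on first access, while the class-level
# branch simply forgets the stale value (as the _reset() call below
# does) so the class-level lazy loader runs again.  `lazy_load` is a
# hypothetical zero-argument loader.
def sketch_populators(key, is_class_level, lazy_load):
    populators = {"new": []}
    if not is_class_level:
        def set_lazy_callable(state, dict_, row):
            # per-instance override of the class-level behavior
            state.setdefault("callables", {})[key] = lazy_load
        populators["new"].append((key, set_lazy_callable))
    else:
        def reset_for_lazy_callable(state, dict_, row):
            dict_.pop(key, None)  # next access re-triggers the loader
        populators["new"].append((key, reset_for_lazy_callable))
    return populators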
+ state._reset(dict_, key) + + populators["new"].append((self.key, reset_for_lazy_callable)) + + +class LoadLazyAttribute(object): + """serializable loader object used by LazyLoader""" + + def __init__(self, key, initiating_strategy): + self.key = key + self.strategy_key = initiating_strategy.strategy_key + + def __call__(self, state, passive=attributes.PASSIVE_OFF): + key = self.key + instance_mapper = state.manager.mapper + prop = instance_mapper._props[key] + strategy = prop._strategies[self.strategy_key] + + return strategy._load_for_state(state, passive) + + +@properties.RelationshipProperty.strategy_for(lazy="immediate") +class ImmediateLoader(AbstractRelationshipLoader): + __slots__ = () + def init_class_attribute(self, mapper): self.parent_property.\ - _get_strategy(LazyLoader).\ - init_class_attribute(mapper) - - def setup_query(self, context, entity, - path, adapter, column_collection=None, - parentmapper=None, **kwargs): + _get_strategy((("lazy", "select"),)).\ + init_class_attribute(mapper) + + def setup_query( + self, context, entity, + path, loadopt, adapter, column_collection=None, + parentmapper=None, **kwargs): + pass + + def create_row_processor( + self, context, path, loadopt, + mapper, result, adapter, populators): + def load_immediate(state, dict_, row): + state.get_impl(self.key).get(state, dict_) + + populators["delayed"].append((self.key, load_immediate)) + + +@log.class_logger +@properties.RelationshipProperty.strategy_for(lazy="subquery") +class SubqueryLoader(AbstractRelationshipLoader): + __slots__ = 'join_depth', + + def __init__(self, parent, strategy_key): + super(SubqueryLoader, self).__init__(parent, strategy_key) + self.join_depth = self.parent_property.join_depth + + def init_class_attribute(self, mapper): + self.parent_property.\ + _get_strategy((("lazy", "select"),)).\ + init_class_attribute(mapper) + + def setup_query( + self, context, entity, + path, loadopt, adapter, + column_collection=None, + parentmapper=None, **kwargs): if not context.query._enable_eagerloads: return - - path = path + (self.key, ) + elif context.query._yield_per: + context.query._no_yield_per("subquery") + + path = path[self.parent_property] # build up a path indicating the path from the leftmost # entity to the thing we're subquery loading. - subq_path = context.attributes.get(('subquery_path', None), ()) + with_poly_info = path.get( + context.attributes, + "path_with_polymorphic", None) + if with_poly_info is not None: + effective_entity = with_poly_info.entity + else: + effective_entity = self.mapper + + subq_path = context.attributes.get( + ('subquery_path', None), + orm_util.PathRegistry.root) subq_path = subq_path + path - reduced_path = interfaces._reduce_path(path) - - # join-depth / recursion check - if ("loaderstrategy", reduced_path) not in context.attributes: + # if not via query option, check for + # a cycle + if not path.contains(context.attributes, "loader"): if self.join_depth: - if len(path) / 2 > self.join_depth: + if path.length / 2 > self.join_depth: return - else: - if self.mapper.base_mapper in interfaces._reduce_path(subq_path): - return - + elif subq_path.contains_mapper(self.mapper): + return + + leftmost_mapper, leftmost_attr, leftmost_relationship = \ + self._get_leftmost(subq_path) + orig_query = context.attributes.get( - ("orig_query", SubqueryLoader), - context.query) + ("orig_query", SubqueryLoader), + context.query) + + # generate a new Query from the original, then + # produce a subquery from it. 
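# The recursion guard from setup_query above, in isolation: a
# self-referential subquery load stops at a configured join_depth or,
# absent one, as soon as the same mapper already appears in the path.
# Paths are flattened to [mapper, prop, mapper, prop, ...] lists here
# as a simplification of PathRegistry.
def may_recurse(path, mapper, join_depth=None):
    if join_depth is not None:
        # each mapper/relationship hop adds two entries to the path
        return len(path) // 2 <= join_depth
    return mapper not in path[0::2]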
+ left_alias = self._generate_from_original_query( + orig_query, leftmost_mapper, + leftmost_attr, leftmost_relationship, + entity.entity_zero + ) + + # generate another Query that will join the + # left alias to the target relationships. + # basically doing a longhand + # "from_self()". (from_self() itself not quite industrial + # strength enough for all contingencies...but very close) + q = orig_query.session.query(effective_entity) + q._attributes = { + ("orig_query", SubqueryLoader): orig_query, + ('subquery_path', None): subq_path + } + + q = q._set_enable_single_crit(False) + to_join, local_attr, parent_alias = \ + self._prep_for_joins(left_alias, subq_path) + q = q.order_by(*local_attr) + q = q.add_columns(*local_attr) + q = self._apply_joins( + q, to_join, left_alias, + parent_alias, effective_entity) + + q = self._setup_options(q, subq_path, orig_query, effective_entity) + q = self._setup_outermost_orderby(q) + + # add new query to attributes to be picked up + # by create_row_processor + path.set(context.attributes, "subquery", q) + + def _get_leftmost(self, subq_path): + subq_path = subq_path.path + subq_mapper = orm_util._class_to_mapper(subq_path[0]) # determine attributes of the leftmost mapper - if self.parent.isa(subq_path[0]) and self.key==subq_path[1]: + if self.parent.isa(subq_mapper) and \ + self.parent_property is subq_path[1]: leftmost_mapper, leftmost_prop = \ - self.parent, self.parent_property + self.parent, self.parent_property else: leftmost_mapper, leftmost_prop = \ - subq_path[0], \ - subq_path[0].get_property(subq_path[1]) - leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop) - + subq_mapper, \ + subq_path[1] + + leftmost_cols = leftmost_prop.local_columns + leftmost_attr = [ - leftmost_mapper._get_col_to_prop(c).class_attribute + getattr( + subq_path[0].entity, + leftmost_mapper._columntoproperty[c].key) for c in leftmost_cols ] + return leftmost_mapper, leftmost_attr, leftmost_prop + + def _generate_from_original_query( + self, + orig_query, leftmost_mapper, + leftmost_attr, leftmost_relationship, orig_entity + ): # reformat the original query # to look only for significant columns - q = orig_query._clone() - # TODO: why does polymporphic etc. require hardcoding - # into _adapt_col_list ? Does query.add_columns(...) work - # with polymorphic loading ? - q._set_entities(q._adapt_col_list(leftmost_attr)) + q = orig_query._clone().correlate(None) + + # set a real "from" if not present, as this is more + # accurate than just going off of the column expression + if not q._from_obj and orig_entity.is_mapper and \ + orig_entity.mapper.isa(leftmost_mapper): + q._set_select_from([orig_entity], False) + target_cols = q._adapt_col_list(leftmost_attr) + + # select from the identity columns of the outer. This will remove + # other columns from the query that might suggest the right entity + # which is why we try to _set_select_from above. 
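# The column-pruning step performed here, shown standalone with
# 1.x-era Core (the table, names and filter are illustrative): the
# user's SELECT keeps only the columns needed to join the
# relationship, then becomes a labeled subquery for the collection
# query to join against.
from sqlalchemy import Table, Column, Integer, String, MetaData, select

md = MetaData()
users = Table('users', md,
              Column('id', Integer, primary_key=True),
              Column('name', String(50)))

orig = select([users.c.id, users.c.name]).where(users.c.name == 'ed')
pruned = orig.with_only_columns([users.c.id])  # ~ q._set_entities(...)
anon = pruned.alias('anon_1')                  # ~ q.with_labels().subquery()
print(anon.select())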
+ q._set_entities(target_cols) + + distinct_target_key = leftmost_relationship.distinct_target_key + + if distinct_target_key is True: + q._distinct = True + elif distinct_target_key is None: + # if target_cols refer to a non-primary key or only + # part of a composite primary key, set the q as distinct + for t in set(c.table for c in target_cols): + if not set(target_cols).issuperset(t.primary_key): + q._distinct = True + break + + if q._order_by is False: + q._order_by = leftmost_mapper.order_by # don't need ORDER BY if no limit/offset if q._limit is None and q._offset is None: @@ -700,261 +882,522 @@ class SubqueryLoader(AbstractRelationshipLoader): # the original query now becomes a subquery # which we'll join onto. - embed_q = q.with_labels().subquery() - left_alias = mapperutil.AliasedClass(leftmost_mapper, embed_q) - - # q becomes a new query. basically doing a longhand - # "from_self()". (from_self() itself not quite industrial - # strength enough for all contingencies...but very close) - - q = q.session.query(self.mapper) - q._attributes = { - ("orig_query", SubqueryLoader): orig_query, - ('subquery_path', None) : subq_path - } + embed_q = q.with_labels().subquery() + left_alias = orm_util.AliasedClass( + leftmost_mapper, embed_q, + use_mapper_path=True) + return left_alias + + def _prep_for_joins(self, left_alias, subq_path): # figure out what's being joined. a.k.a. the fun part - to_join = [ - (subq_path[i], subq_path[i+1]) - for i in xrange(0, len(subq_path), 2) - ] + to_join = [] + pairs = list(subq_path.pairs()) + + for i, (mapper, prop) in enumerate(pairs): + if i > 0: + # look at the previous mapper in the chain - + # if it is as or more specific than this prop's + # mapper, use that instead. + # note we have an assumption here that + # the non-first element is always going to be a mapper, + # not an AliasedClass + + prev_mapper = pairs[i - 1][1].mapper + to_append = prev_mapper if prev_mapper.isa(mapper) else mapper + else: + to_append = mapper + + to_join.append((to_append, prop.key)) + + # determine the immediate parent class we are joining from, + # which needs to be aliased. if len(to_join) < 2: + # in the case of a one level eager load, this is the + # leftmost "left_alias". parent_alias = left_alias else: - parent_alias = mapperutil.AliasedClass(self.parent) + info = inspect(to_join[-1][0]) + if info.is_aliased_class: + parent_alias = info.entity + else: + # alias a plain mapper as we may be + # joining multiple times + parent_alias = orm_util.AliasedClass( + info.entity, + use_mapper_path=True) - local_cols, remote_cols = \ - self._local_remote_columns(self.parent_property) + local_cols = self.parent_property.local_columns local_attr = [ - getattr(parent_alias, self.parent._get_col_to_prop(c).key) + getattr(parent_alias, self.parent._columntoproperty[c].key) for c in local_cols ] - q = q.order_by(*local_attr) - q = q.add_columns(*local_attr) - - for i, (mapper, key) in enumerate(to_join): - - # we need to use query.join() as opposed to - # orm.join() here because of the - # rich behavior it brings when dealing with - # "with_polymorphic" mappers. "aliased" - # and "from_joinpoint" take care of most of - # the chaining and aliasing for us. 
- - first = i == 0 - middle = i < len(to_join) - 1 - second_to_last = i == len(to_join) - 2 - - if first: - attr = getattr(left_alias, key) - else: - attr = key - - if second_to_last: - q = q.join((parent_alias, attr), from_joinpoint=True) - else: - q = q.join(attr, aliased=middle, from_joinpoint=True) + return to_join, local_attr, parent_alias + def _apply_joins( + self, q, to_join, left_alias, parent_alias, + effective_entity): + + ltj = len(to_join) + if ltj == 1: + to_join = [ + getattr(left_alias, to_join[0][1]).of_type(effective_entity) + ] + elif ltj == 2: + to_join = [ + getattr(left_alias, to_join[0][1]).of_type(parent_alias), + getattr(parent_alias, to_join[-1][1]).of_type(effective_entity) + ] + elif ltj > 2: + middle = [ + ( + orm_util.AliasedClass(item[0]) + if not inspect(item[0]).is_aliased_class + else item[0].entity, + item[1] + ) for item in to_join[1:-1] + ] + inner = [] + + while middle: + item = middle.pop(0) + attr = getattr(item[0], item[1]) + if middle: + attr = attr.of_type(middle[0][0]) + else: + attr = attr.of_type(parent_alias) + + inner.append(attr) + + to_join = [ + getattr(left_alias, to_join[0][1]).of_type(inner[0].parent) + ] + inner + [ + getattr(parent_alias, to_join[-1][1]).of_type(effective_entity) + ] + + for attr in to_join: + q = q.join(attr, from_joinpoint=True) + return q + + def _setup_options(self, q, subq_path, orig_query, effective_entity): # propagate loader options etc. to the new query. # these will fire relative to subq_path. q = q._with_current_path(subq_path) q = q._conditional_options(*orig_query._with_options) + if orig_query._populate_existing: + q._populate_existing = orig_query._populate_existing + return q + + def _setup_outermost_orderby(self, q): if self.parent_property.order_by: # if there's an ORDER BY, alias it the same - # way joinedloader does, but we have to pull out + # way joinedloader does, but we have to pull out # the "eagerjoin" from the query. # this really only picks up the "secondary" table # right now. eagerjoin = q._from_obj[0] eager_order_by = \ - eagerjoin._target_adapter.\ - copy_and_process( - util.to_list( - self.parent_property.order_by - ) - ) + eagerjoin._target_adapter.\ + copy_and_process( + util.to_list( + self.parent_property.order_by + ) + ) q = q.order_by(*eager_order_by) - - # add new query to attributes to be picked up - # by create_row_processor - context.attributes[('subquery', reduced_path)] = q - - def _local_remote_columns(self, prop): - if prop.secondary is None: - return zip(*prop.local_remote_pairs) - else: - return \ - [p[0] for p in prop.synchronize_pairs],\ - [ - p[0] for p in prop. - secondary_synchronize_pairs - ] - - def create_row_processor(self, context, path, mapper, row, adapter): - path = path + (self.key,) + return q - path = interfaces._reduce_path(path) - - if ('subquery', path) not in context.attributes: - return None, None - - local_cols, remote_cols = self._local_remote_columns(self.parent_property) + class _SubqCollections(object): + """Given a :class:`.Query` used to emit the "subquery load", + provide a load interface that executes the query at the + first moment a value is needed. 
+ + """ + _data = None + + def __init__(self, subq): + self.subq = subq + + def get(self, key, default): + if self._data is None: + self._load() + return self._data.get(key, default) + + def _load(self): + self._data = dict( + (k, [vv[0] for vv in v]) + for k, v in itertools.groupby( + self.subq, + lambda x: x[1:] + ) + ) + + def loader(self, state, dict_, row): + if self._data is None: + self._load() + + def create_row_processor( + self, context, path, loadopt, + mapper, result, adapter, populators): + if not self.parent.class_manager[self.key].impl.supports_population: + raise sa_exc.InvalidRequestError( + "'%s' does not support object " + "population - eager loading cannot be applied." % + self) + + path = path[self.parent_property] + + subq = path.get(context.attributes, 'subquery') + + if subq is None: + return + + assert subq.session is context.session, ( + "Subquery session doesn't refer to that of " + "our context. Are there broken context caching " + "schemes being used?" + ) + + local_cols = self.parent_property.local_columns + + # cache the loaded collections in the context + # so that inheriting mappers don't re-load when they + # call upon create_row_processor again + collections = path.get(context.attributes, "collections") + if collections is None: + collections = self._SubqCollections(subq) + path.set(context.attributes, 'collections', collections) - remote_attr = [ - self.mapper._get_col_to_prop(c).key - for c in remote_cols] - - q = context.attributes[('subquery', path)] - - collections = dict( - (k, [v[0] for v in v]) - for k, v in itertools.groupby( - q, - lambda x:x[1:] - )) - if adapter: local_cols = [adapter.columns[c] for c in local_cols] - + if self.uselist: - def execute(state, dict_, row): - collection = collections.get( - tuple([row[col] for col in local_cols]), - () - ) - state.get_impl(self.key).\ - set_committed_value(state, dict_, collection) + self._create_collection_loader( + context, collections, local_cols, populators) else: - def execute(state, dict_, row): - collection = collections.get( - tuple([row[col] for col in local_cols]), - (None,) - ) - if len(collection) > 1: - util.warn( - "Multiple rows returned with " - "uselist=False for eagerly-loaded attribute '%s' " - % self) - - scalar = collection[0] - state.get_impl(self.key).\ - set_committed_value(state, dict_, scalar) - - return execute, None + self._create_scalar_loader( + context, collections, local_cols, populators) -log.class_logger(SubqueryLoader) + def _create_collection_loader( + self, context, collections, local_cols, populators): + def load_collection_from_subq(state, dict_, row): + collection = collections.get( + tuple([row[col] for col in local_cols]), + () + ) + state.get_impl(self.key).\ + set_committed_value(state, dict_, collection) -class EagerLoader(AbstractRelationshipLoader): - """Strategize a relationship() that loads within the process - of the parent object being selected.""" - - def init(self): - super(EagerLoader, self).init() + def load_collection_from_subq_existing_row(state, dict_, row): + if self.key not in dict_: + load_collection_from_subq(state, dict_, row) + + populators["new"].append( + (self.key, load_collection_from_subq)) + populators["existing"].append( + (self.key, load_collection_from_subq_existing_row)) + + if context.invoke_all_eagers: + populators["eager"].append((self.key, collections.loader)) + + def _create_scalar_loader( + self, context, collections, local_cols, populators): + def load_scalar_from_subq(state, dict_, row): + collection = 
collections.get( + tuple([row[col] for col in local_cols]), + (None,) + ) + if len(collection) > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded attribute '%s' " + % self) + + scalar = collection[0] + state.get_impl(self.key).\ + set_committed_value(state, dict_, scalar) + + def load_scalar_from_subq_existing_row(state, dict_, row): + if self.key not in dict_: + load_scalar_from_subq(state, dict_, row) + + populators["new"].append( + (self.key, load_scalar_from_subq)) + populators["existing"].append( + (self.key, load_scalar_from_subq_existing_row)) + if context.invoke_all_eagers: + populators["eager"].append((self.key, collections.loader)) + + +@log.class_logger +@properties.RelationshipProperty.strategy_for(lazy="joined") +@properties.RelationshipProperty.strategy_for(lazy=False) +class JoinedLoader(AbstractRelationshipLoader): + """Provide loading behavior for a :class:`.RelationshipProperty` + using joined eager loading. + + """ + + __slots__ = 'join_depth', '_aliased_class_pool' + + def __init__(self, parent, strategy_key): + super(JoinedLoader, self).__init__(parent, strategy_key) self.join_depth = self.parent_property.join_depth + self._aliased_class_pool = [] def init_class_attribute(self, mapper): self.parent_property.\ - _get_strategy(LazyLoader).init_class_attribute(mapper) - - def setup_query(self, context, entity, path, adapter, \ - column_collection=None, parentmapper=None, - **kwargs): - """Add a left outer join to the statement thats being constructed.""" + _get_strategy((("lazy", "select"),)).init_class_attribute(mapper) + + def setup_query( + self, context, entity, path, loadopt, adapter, + column_collection=None, parentmapper=None, + chained_from_outerjoin=False, + **kwargs): + """Add a left outer join to the statement that's being constructed.""" if not context.query._enable_eagerloads: return - - path = path + (self.key,) - - reduced_path = interfaces._reduce_path(path) - - # check for user-defined eager alias - if ("user_defined_eager_row_processor", reduced_path) in\ - context.attributes: - clauses = context.attributes[ - ("user_defined_eager_row_processor", - reduced_path)] - - adapter = entity._get_entity_clauses(context.query, context) - if adapter and clauses: - context.attributes[ - ("user_defined_eager_row_processor", - reduced_path)] = clauses = clauses.wrap(adapter) - elif adapter: - context.attributes[ - ("user_defined_eager_row_processor", - reduced_path)] = clauses = adapter - - add_to_collection = context.primary_columns - + elif context.query._yield_per and self.uselist: + context.query._no_yield_per("joined collection") + + path = path[self.parent_property] + + with_polymorphic = None + + user_defined_adapter = self._init_user_defined_eager_proc( + loadopt, context) if loadopt else False + + if user_defined_adapter is not False: + clauses, adapter, add_to_collection = \ + self._setup_query_on_user_defined_adapter( + context, entity, path, adapter, + user_defined_adapter + ) else: - # check for join_depth or basic recursion, - # if the current path was not explicitly stated as - # a desired "loaderstrategy" (i.e. 
via query.options()) - if ("loaderstrategy", reduced_path) not in context.attributes: + # if not via query option, check for + # a cycle + if not path.contains(context.attributes, "loader"): if self.join_depth: - if len(path) / 2 > self.join_depth: - return - else: - if self.mapper.base_mapper in reduced_path: + if path.length / 2 > self.join_depth: return + elif path.contains_mapper(self.mapper): + return - clauses = mapperutil.ORMAdapter( - mapperutil.AliasedClass(self.mapper), - equivalents=self.mapper._equivalent_columns, - adapt_required=True) + clauses, adapter, add_to_collection, chained_from_outerjoin = \ + self._generate_row_adapter( + context, entity, path, loadopt, adapter, + column_collection, parentmapper, chained_from_outerjoin + ) - if self.parent_property.direction != interfaces.MANYTOONE: - context.multi_row_eager_loaders = True + with_poly_info = path.get( + context.attributes, + "path_with_polymorphic", + None + ) + if with_poly_info is not None: + with_polymorphic = with_poly_info.with_polymorphic_mappers + else: + with_polymorphic = None - context.create_eager_joins.append( - (self._create_eager_join, context, - entity, path, adapter, - parentmapper, clauses) + path = path[self.mapper] + + loading._setup_entity_query( + context, self.mapper, entity, + path, clauses, add_to_collection, + with_polymorphic=with_polymorphic, + parentmapper=self.mapper, + chained_from_outerjoin=chained_from_outerjoin) + + if with_poly_info is not None and \ + None in set(context.secondary_columns): + raise sa_exc.InvalidRequestError( + "Detected unaliased columns when generating joined " + "load. Make sure to use aliased=True or flat=True " + "when using joined loading with with_polymorphic()." ) - add_to_collection = context.secondary_columns - context.attributes[ - ("eager_row_processor", reduced_path) - ] = clauses + def _init_user_defined_eager_proc(self, loadopt, context): + + # check if the opt applies at all + if "eager_from_alias" not in loadopt.local_opts: + # nope + return False + + path = loadopt.path.parent + + # the option applies. check if the "user_defined_eager_row_processor" + # has been built up. + adapter = path.get( + context.attributes, + "user_defined_eager_row_processor", False) + if adapter is not False: + # just return it + return adapter + + # otherwise figure it out. + alias = loadopt.local_opts["eager_from_alias"] + + root_mapper, prop = path[-2:] + + #from .mapper import Mapper + #from .interfaces import MapperProperty + #assert isinstance(root_mapper, Mapper) + #assert isinstance(prop, MapperProperty) + + if alias is not None: + if isinstance(alias, str): + alias = prop.target.alias(alias) + adapter = sql_util.ColumnAdapter( + alias, + equivalents=prop.mapper._equivalent_columns) + else: + if path.contains(context.attributes, "path_with_polymorphic"): + with_poly_info = path.get( + context.attributes, + "path_with_polymorphic") + adapter = orm_util.ORMAdapter( + with_poly_info.entity, + equivalents=prop.mapper._equivalent_columns) + else: + adapter = context.query._polymorphic_adapters.get( + prop.mapper, None) + path.set( + context.attributes, + "user_defined_eager_row_processor", + adapter) + + return adapter + + def _setup_query_on_user_defined_adapter( + self, context, entity, + path, adapter, user_defined_adapter): + + # apply some more wrapping to the "user defined adapter" + # if we are setting up the query for SQL render. 
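# For orientation, user code of the sort that reaches the
# "eager_from_alias" branch above.  User and Address are assumed
# mapped classes and `session` an open Session; none of these are
# defined in this patch.
from sqlalchemy.orm import aliased, contains_eager

adalias = aliased(Address)
q = (session.query(User)
     .outerjoin(adalias, User.addresses)
     # alias= is what lands in loadopt.local_opts["eager_from_alias"]
     .options(contains_eager(User.addresses, alias=adalias)))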
+ adapter = entity._get_entity_clauses(context.query, context) + + if adapter and user_defined_adapter: + user_defined_adapter = user_defined_adapter.wrap(adapter) + path.set( + context.attributes, "user_defined_eager_row_processor", + user_defined_adapter) + elif adapter: + user_defined_adapter = adapter + path.set( + context.attributes, "user_defined_eager_row_processor", + user_defined_adapter) + + add_to_collection = context.primary_columns + return user_defined_adapter, adapter, add_to_collection + + def _gen_pooled_aliased_class(self, context): + # keep a local pool of AliasedClass objects that get re-used. + # we need one unique AliasedClass per query per appearance of our + # entity in the query. + + key = ('joinedloader_ac', self) + if key not in context.attributes: + context.attributes[key] = idx = 0 + else: + context.attributes[key] = idx = context.attributes[key] + 1 + + if idx >= len(self._aliased_class_pool): + to_adapt = orm_util.AliasedClass( + self.mapper, + flat=True, + use_mapper_path=True) + # load up the .columns collection on the Alias() before + # the object becomes shared among threads. this prevents + # races for column identities. + inspect(to_adapt).selectable.c + + self._aliased_class_pool.append(to_adapt) + + return self._aliased_class_pool[idx] + + def _generate_row_adapter( + self, + context, entity, path, loadopt, adapter, + column_collection, parentmapper, chained_from_outerjoin): + with_poly_info = path.get( + context.attributes, + "path_with_polymorphic", + None + ) + if with_poly_info: + to_adapt = with_poly_info.entity + else: + to_adapt = self._gen_pooled_aliased_class(context) + + clauses = inspect(to_adapt)._memo( + ("joinedloader_ormadapter", self), + orm_util.ORMAdapter, + to_adapt, + equivalents=self.mapper._equivalent_columns, + adapt_required=True, allow_label_resolve=False, + anonymize_labels=True + ) + + assert clauses.aliased_class is not None + + if self.parent_property.uselist: + context.multi_row_eager_loaders = True + + innerjoin = ( + loadopt.local_opts.get( + 'innerjoin', self.parent_property.innerjoin) + if loadopt is not None + else self.parent_property.innerjoin + ) + + if not innerjoin: + # if this is an outer join, all non-nested eager joins from + # this path must also be outer joins + chained_from_outerjoin = True + + context.create_eager_joins.append( + ( + self._create_eager_join, context, + entity, path, adapter, + parentmapper, clauses, innerjoin, chained_from_outerjoin + ) + ) + + add_to_collection = context.secondary_columns + path.set(context.attributes, "eager_row_processor", clauses) + + return clauses, adapter, add_to_collection, chained_from_outerjoin + + def _create_eager_join( + self, context, entity, + path, adapter, parentmapper, + clauses, innerjoin, chained_from_outerjoin): - for value in self.mapper._iterate_polymorphic_properties(): - value.setup( - context, - entity, - path + (self.mapper,), - clauses, - parentmapper=self.mapper, - column_collection=add_to_collection) - - def _create_eager_join(self, context, entity, - path, adapter, parentmapper, clauses): - if parentmapper is None: localparent = entity.mapper else: localparent = parentmapper - + # whether or not the Query will wrap the selectable in a subquery, - # and then attach eager load joins to that (i.e., in the case of + # and then attach eager load joins to that (i.e., in the case of # LIMIT/OFFSET etc.) 
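# The alias pool built in _gen_pooled_aliased_class above, reduced to
# its bookkeeping: a per-query counter indexes an append-only pool, so
# each appearance of the entity within one query gets a distinct alias
# while aliases are reused across queries.  `make_alias` stands in for
# constructing an AliasedClass.
def pooled_alias(context_attrs, pool, make_alias):
    idx = context_attrs['joinedloader_ac'] = \
        context_attrs.get('joinedloader_ac', -1) + 1
    if idx >= len(pool):
        pool.append(make_alias())
    return pool[idx]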
should_nest_selectable = context.multi_row_eager_loaders and \ context.query._should_nest_selectable - + entity_key = None + if entity not in context.eager_joins and \ not should_nest_selectable and \ - context.from_clause: - index, clause = \ - sql_util.find_join_source( - context.from_clause, entity.selectable) + context.from_clause: + index, clause = sql_util.find_join_source( + context.from_clause, entity.selectable) if clause is not None: # join to an existing FROM clause on the query. # key it to its list index in the eager_joins dict. - # Query._compile_context will adapt as needed and + # Query._compile_context will adapt as needed and # append to the FROM clause of the select(). entity_key, default_towrap = index, clause @@ -963,267 +1406,302 @@ class EagerLoader(AbstractRelationshipLoader): towrap = context.eager_joins.setdefault(entity_key, default_towrap) - join_to_left = False if adapter: if getattr(adapter, 'aliased_class', None): + # joining from an adapted entity. The adapted entity + # might be a "with_polymorphic", so resolve that to our + # specific mapper's entity before looking for our attribute + # name on it. + efm = inspect(adapter.aliased_class).\ + _entity_for_mapper( + localparent + if localparent.isa(self.parent) else self.parent) + + # look for our attribute on the adapted entity, else fall back + # to our straight property onclause = getattr( - adapter.aliased_class, self.key, - self.parent_property) + efm.entity, self.key, + self.parent_property) else: onclause = getattr( - mapperutil.AliasedClass( - self.parent, - adapter.selectable - ), - self.key, self.parent_property - ) - - if onclause is self.parent_property: - # TODO: this is a temporary hack to - # account for polymorphic eager loads where - # the eagerload is referencing via of_type(). - join_to_left = True + orm_util.AliasedClass( + self.parent, + adapter.selectable, + use_mapper_path=True + ), + self.key, self.parent_property + ) + else: onclause = self.parent_property - innerjoin = context.attributes.get( - ("eager_join_type", path), - self.parent_property.innerjoin) + assert clauses.aliased_class is not None - context.eager_joins[entity_key] = eagerjoin = \ - mapperutil.join( - towrap, - clauses.aliased_class, - onclause, - join_to_left=join_to_left, - isouter=not innerjoin - ) + attach_on_outside = ( + not chained_from_outerjoin or + not innerjoin or innerjoin == 'unnested') + + if attach_on_outside: + # this is the "classic" eager join case. + eagerjoin = orm_util._ORMJoin( + towrap, + clauses.aliased_class, + onclause, + isouter=not innerjoin or ( + chained_from_outerjoin and isinstance(towrap, sql.Join) + ), _left_memo=self.parent, _right_memo=self.mapper + ) + else: + # all other cases are innerjoin=='nested' approach + eagerjoin = self._splice_nested_inner_join( + path, towrap, clauses, onclause) + + context.eager_joins[entity_key] = eagerjoin # send a hint to the Query as to where it may "splice" this join eagerjoin.stop_on = entity.selectable - if self.parent_property.secondary is None and \ - not parentmapper: + if not parentmapper: # for parentclause that is the non-eager end of the join, - # ensure all the parent cols in the primaryjoin are actually + # ensure all the parent cols in the primaryjoin are actually # in the - # columns clause (i.e. are not deferred), so that aliasing applied - # by the Query propagates those columns outward. - # This has the effect + # columns clause (i.e. are not deferred), so that aliasing applied + # by the Query propagates those columns outward. 
+ # This has the effect # of "undefering" those columns. - for col in sql_util.find_columns( - self.parent_property.primaryjoin): + for col in sql_util._find_columns( + self.parent_property.primaryjoin): if localparent.mapped_table.c.contains_column(col): if adapter: col = adapter.columns[col] context.primary_columns.append(col) - - if self.parent_property.order_by: - context.eager_order_by += \ - eagerjoin._target_adapter.\ - copy_and_process( - util.to_list( - self.parent_property.order_by - ) - ) - - def _create_eager_adapter(self, context, row, adapter, path): - reduced_path = interfaces._reduce_path(path) - if ("user_defined_eager_row_processor", reduced_path) in \ - context.attributes: - decorator = context.attributes[ - ("user_defined_eager_row_processor", - reduced_path)] - # user defined eagerloads are part of the "primary" + if self.parent_property.order_by: + context.eager_order_by += eagerjoin._target_adapter.\ + copy_and_process( + util.to_list( + self.parent_property.order_by + ) + ) + + def _splice_nested_inner_join( + self, path, join_obj, clauses, onclause, splicing=False): + + if splicing is False: + # first call is always handed a join object + # from the outside + assert isinstance(join_obj, orm_util._ORMJoin) + elif isinstance(join_obj, sql.selectable.FromGrouping): + return self._splice_nested_inner_join( + path, join_obj.element, clauses, onclause, splicing + ) + elif not isinstance(join_obj, orm_util._ORMJoin): + if path[-2] is splicing: + return orm_util._ORMJoin( + join_obj, clauses.aliased_class, + onclause, isouter=False, + _left_memo=splicing, + _right_memo=path[-1].mapper + ) + else: + # only here if splicing == True + return None + + target_join = self._splice_nested_inner_join( + path, join_obj.right, clauses, + onclause, join_obj._right_memo) + if target_join is None: + right_splice = False + target_join = self._splice_nested_inner_join( + path, join_obj.left, clauses, + onclause, join_obj._left_memo) + if target_join is None: + # should only return None when recursively called, + # e.g. splicing==True + assert splicing is not False, \ + "assertion failed attempting to produce joined eager loads" + return None + else: + right_splice = True + + if right_splice: + # for a right splice, attempt to flatten out + # a JOIN b JOIN c JOIN .. to avoid needless + # parenthesis nesting + if not join_obj.isouter and not target_join.isouter: + eagerjoin = join_obj._splice_into_center(target_join) + else: + eagerjoin = orm_util._ORMJoin( + join_obj.left, target_join, + join_obj.onclause, isouter=join_obj.isouter, + _left_memo=join_obj._left_memo) + else: + eagerjoin = orm_util._ORMJoin( + target_join, join_obj.right, + join_obj.onclause, isouter=join_obj.isouter, + _right_memo=join_obj._right_memo) + + eagerjoin._target_adapter = target_join._target_adapter + return eagerjoin + + def _create_eager_adapter(self, context, result, adapter, path, loadopt): + user_defined_adapter = self._init_user_defined_eager_proc( + loadopt, context) if loadopt else False + + if user_defined_adapter is not False: + decorator = user_defined_adapter + # user defined eagerloads are part of the "primary" # portion of the load. # the adapters applied to the Query should be honored. 
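# Adapter precedence in _create_eager_adapter, modeled with adapters
# as plain row->row callables (the real objects are ORMAdapters and
# wrap() composes them; the composition order shown is illustrative):
# a user-defined adapter wins and is combined with the Query's own
# adapter when both exist; otherwise the path's "eager_row_processor"
# applies; with nothing at all the load degrades to lazy (False).
def resolve_decorator(user_defined, query_adapter, path_adapter):
    if user_defined is not False:
        if query_adapter and user_defined:
            return lambda row: user_defined(query_adapter(row))
        return user_defined or query_adapter
    return path_adapter if path_adapter is not None else False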
if context.adapter and decorator: decorator = decorator.wrap(context.adapter) elif context.adapter: decorator = context.adapter - elif ("eager_row_processor", reduced_path) in context.attributes: - decorator = context.attributes[ - ("eager_row_processor", reduced_path)] else: - return False + decorator = path.get(context.attributes, "eager_row_processor") + if decorator is None: + return False - try: - identity_key = self.mapper.identity_key_from_row(row, decorator) + if self.mapper._result_has_identity_key(result, decorator): return decorator - except KeyError, k: - # no identity key - dont return a row + else: + # no identity key - don't return a row # processor, will cause a degrade to lazy return False - def create_row_processor(self, context, path, mapper, row, adapter): - path = path + (self.key,) - + def create_row_processor( + self, context, path, loadopt, mapper, + result, adapter, populators): + if not self.parent.class_manager[self.key].impl.supports_population: + raise sa_exc.InvalidRequestError( + "'%s' does not support object " + "population - eager loading cannot be applied." % + self + ) + + our_path = path[self.parent_property] + eager_adapter = self._create_eager_adapter( - context, - row, - adapter, path) - + context, + result, + adapter, our_path, loadopt) + if eager_adapter is not False: key = self.key - _instance = self.mapper._instance_processor( - context, - path + (self.mapper,), - eager_adapter) - + + _instance = loading._instance_processor( + self.mapper, + context, + result, + our_path[self.mapper], + eager_adapter) + if not self.uselist: - def new_execute(state, dict_, row): - # set a scalar object instance directly on the parent - # object, bypassing InstrumentedAttribute event handlers. - dict_[key] = _instance(row, None) - - def existing_execute(state, dict_, row): - # call _instance on the row, even though the object has - # been created, so that we further descend into properties - existing = _instance(row, None) - if existing is not None \ - and key in dict_ \ - and existing is not dict_[key]: - util.warn( - "Multiple rows returned with " - "uselist=False for eagerly-loaded attribute '%s' " - % self) - return new_execute, existing_execute + self._create_scalar_loader(context, key, _instance, populators) else: - def new_execute(state, dict_, row): - collection = attributes.init_state_collection( - state, dict_, key) - result_list = util.UniqueAppender(collection, - 'append_without_event') - context.attributes[(state, key)] = result_list - _instance(row, result_list) - - def existing_execute(state, dict_, row): - if (state, key) in context.attributes: - result_list = context.attributes[(state, key)] - else: - # appender_key can be absent from context.attributes - # with isnew=False when self-referential eager loading - # is used; the same instance may be present in two - # distinct sets of result columns - collection = attributes.init_state_collection(state, - dict_, key) - result_list = util.UniqueAppender( - collection, - 'append_without_event') - context.attributes[(state, key)] = result_list - _instance(row, result_list) - return new_execute, existing_execute + self._create_collection_loader( + context, key, _instance, populators) else: - return self.parent_property.\ - _get_strategy(LazyLoader).\ - create_row_processor( - context, path, - mapper, row, adapter) + self.parent_property._get_strategy((("lazy", "select"),)).\ + create_row_processor( + context, path, loadopt, + mapper, result, adapter, populators) -log.class_logger(EagerLoader) + def 
_create_collection_loader(self, context, key, _instance, populators): + def load_collection_from_joined_new_row(state, dict_, row): + collection = attributes.init_state_collection( + state, dict_, key) + result_list = util.UniqueAppender(collection, + 'append_without_event') + context.attributes[(state, key)] = result_list + inst = _instance(row) + if inst is not None: + result_list.append(inst) -class EagerLazyOption(StrategizedOption): - def __init__(self, key, lazy=True, chained=False, - propagate_to_loaders=True - ): - super(EagerLazyOption, self).__init__(key) - self.lazy = lazy - self.chained = chained - self.propagate_to_loaders = propagate_to_loaders - self.strategy_cls = factory(lazy) - - @property - def is_eager(self): - return self.lazy in (False, 'joined', 'subquery') - - @property - def is_chained(self): - return self.is_eager and self.chained + def load_collection_from_joined_existing_row(state, dict_, row): + if (state, key) in context.attributes: + result_list = context.attributes[(state, key)] + else: + # appender_key can be absent from context.attributes + # with isnew=False when self-referential eager loading + # is used; the same instance may be present in two + # distinct sets of result columns + collection = attributes.init_state_collection( + state, dict_, key) + result_list = util.UniqueAppender( + collection, + 'append_without_event') + context.attributes[(state, key)] = result_list + inst = _instance(row) + if inst is not None: + result_list.append(inst) - def get_strategy_class(self): - return self.strategy_cls + def load_collection_from_joined_exec(state, dict_, row): + _instance(row) -def factory(identifier): - if identifier is False or identifier == 'joined': - return EagerLoader - elif identifier is None or identifier == 'noload': - return NoLoader - elif identifier is False or identifier == 'select': - return LazyLoader - elif identifier == 'subquery': - return SubqueryLoader - else: - return LazyLoader - - - -class EagerJoinOption(PropertyOption): - - def __init__(self, key, innerjoin, chained=False): - super(EagerJoinOption, self).__init__(key) - self.innerjoin = innerjoin - self.chained = chained - - def is_chained(self): - return self.chained + populators["new"].append((self.key, load_collection_from_joined_new_row)) + populators["existing"].append( + (self.key, load_collection_from_joined_existing_row)) + if context.invoke_all_eagers: + populators["eager"].append( + (self.key, load_collection_from_joined_exec)) - def process_query_property(self, query, paths, mappers): - if self.is_chained(): - for path in paths: - query._attributes[("eager_join_type", path)] = self.innerjoin - else: - query._attributes[("eager_join_type", paths[-1])] = self.innerjoin - -class LoadEagerFromAliasOption(PropertyOption): - - def __init__(self, key, alias=None): - super(LoadEagerFromAliasOption, self).__init__(key) - if alias is not None: - if not isinstance(alias, basestring): - m, alias, is_aliased_class = mapperutil._entity_info(alias) - self.alias = alias + def _create_scalar_loader(self, context, key, _instance, populators): + def load_scalar_from_joined_new_row(state, dict_, row): + # set a scalar object instance directly on the parent + # object, bypassing InstrumentedAttribute event handlers. 
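# The collection loaders above, reduced: a joined eager load produces
# one row per (parent, child) pair, so the loader keeps a
# per-(state, key) appender in the per-query context and adds each
# child at most once.  Plain lists stand in for UniqueAppender over an
# instrumented collection.
def accumulate(context_attrs, state_id, key, child):
    bucket = context_attrs.setdefault((state_id, key), [])
    if child is not None and child not in bucket:
        bucket.append(child)  # ~ append_without_event
    return bucket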
+ dict_[key] = _instance(row) - def process_query_property(self, query, paths, mappers): - if self.alias is not None: - if isinstance(self.alias, basestring): - mapper = mappers[-1] - (root_mapper, propname) = paths[-1][-2:] - prop = mapper.get_property(propname, resolve_synonyms=True) - self.alias = prop.target.alias(self.alias) - query._attributes[ - ("user_defined_eager_row_processor", - interfaces._reduce_path(paths[-1])) - ] = sql_util.ColumnAdapter(self.alias) - else: - (root_mapper, propname) = paths[-1][-2:] - mapper = mappers[-1] - prop = mapper.get_property(propname, resolve_synonyms=True) - adapter = query._polymorphic_adapters.get(prop.mapper, None) - query._attributes[ - ("user_defined_eager_row_processor", - interfaces._reduce_path(paths[-1]))] = adapter + def load_scalar_from_joined_existing_row(state, dict_, row): + # call _instance on the row, even though the object has + # been created, so that we further descend into properties + existing = _instance(row) -class _SingleParentValidator(interfaces.AttributeExtension): - def __init__(self, prop): - self.prop = prop + # conflicting value already loaded, this shouldn't happen + if key in dict_: + if existing is not dict_[key]: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded attribute '%s' " + % self) + else: + # this case is when one row has multiple loads of the + # same entity (e.g. via aliasing), one has an attribute + # that the other doesn't. + dict_[key] = existing - def _do_check(self, state, value, oldvalue, initiator): - if value is not None: + def load_scalar_from_joined_exec(state, dict_, row): + _instance(row) + + populators["new"].append((self.key, load_scalar_from_joined_new_row)) + populators["existing"].append( + (self.key, load_scalar_from_joined_existing_row)) + if context.invoke_all_eagers: + populators["eager"].append((self.key, load_scalar_from_joined_exec)) + + +def single_parent_validator(desc, prop): + def _do_check(state, value, oldvalue, initiator): + if value is not None and initiator.key == prop.key: hasparent = initiator.hasparent(attributes.instance_state(value)) - if hasparent and oldvalue is not value: + if hasparent and oldvalue is not value: raise sa_exc.InvalidRequestError( "Instance %s is already associated with an instance " "of %s via its %s attribute, and is only allowed a " - "single parent." % - (mapperutil.instance_str(value), state.class_, self.prop) + "single parent." 
% + (orm_util.instance_str(value), state.class_, prop) ) return value - - def append(self, state, value, initiator): - return self._do_check(state, value, None, initiator) - def set(self, state, value, oldvalue, initiator): - return self._do_check(state, value, oldvalue, initiator) + def append(state, value, initiator): + return _do_check(state, value, None, initiator) + def set_(state, value, oldvalue, initiator): + return _do_check(state, value, oldvalue, initiator) + event.listen( + desc, 'append', append, raw=True, retval=True, + active_history=True) + event.listen( + desc, 'set', set_, raw=True, retval=True, + active_history=True) diff --git a/sqlalchemy/orm/sync.py b/sqlalchemy/orm/sync.py index 30daacb..880428b 100644 --- a/sqlalchemy/orm/sync.py +++ b/sqlalchemy/orm/sync.py @@ -1,98 +1,140 @@ -# mapper/sync.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# orm/sync.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""private module containing functions used for copying data +"""private module containing functions used for copying data between instances based on join conditions. + """ -from sqlalchemy.orm import exc, util as mapperutil +from . import exc, util as orm_util, attributes + + +def populate(source, source_mapper, dest, dest_mapper, + synchronize_pairs, uowcommit, flag_cascaded_pks): + source_dict = source.dict + dest_dict = dest.dict -def populate(source, source_mapper, dest, dest_mapper, - synchronize_pairs, uowcommit, passive_updates): for l, r in synchronize_pairs: try: - value = source_mapper._get_state_attr_by_column(source, l) + # inline of source_mapper._get_state_attr_by_column + prop = source_mapper._columntoproperty[l] + value = source.manager[prop.key].impl.get(source, source_dict, + attributes.PASSIVE_OFF) except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, dest_mapper, r) try: - dest_mapper._set_state_attr_by_column(dest, r, value) + # inline of dest_mapper._set_state_attr_by_column + prop = dest_mapper._columntoproperty[r] + dest.manager[prop.key].impl.set(dest, dest_dict, value, None) except exc.UnmappedColumnError: _raise_col_to_prop(True, source_mapper, l, dest_mapper, r) - - # techically the "r.primary_key" check isn't + + # technically the "r.primary_key" check isn't # needed here, but we check for this condition to limit # how often this logic is invoked for memory/performance # reasons, since we only need this info for a primary key # destination. 
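# What populate() boils down to for one (local, remote) column pair of
# a join condition: read the attribute mapped to the parent-side
# column and assign it to the attribute mapped to the child-side
# column, the plain-Python shape of "child.parent_id = parent.id" at
# flush time.  Attribute names stand in for mapped columns.
def sync_pairs(source, dest, synchronize_pairs):
    for local_attr, remote_attr in synchronize_pairs:
        setattr(dest, remote_attr, getattr(source, local_attr))

# usage: sync_pairs(parent, child, [('id', 'parent_id')])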
- if l.primary_key and r.primary_key and \ - r.references(l) and passive_updates: + if flag_cascaded_pks and l.primary_key and \ + r.primary_key and \ + r.references(l): uowcommit.attributes[("pk_cascaded", dest, r)] = True + +def bulk_populate_inherit_keys( + source_dict, source_mapper, synchronize_pairs): + # a simplified version of populate() used by bulk insert mode + for l, r in synchronize_pairs: + try: + prop = source_mapper._columntoproperty[l] + value = source_dict[prop.key] + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, source_mapper, r) + + try: + prop = source_mapper._columntoproperty[r] + source_dict[prop.key] = value + except exc.UnmappedColumnError: + _raise_col_to_prop(True, source_mapper, l, source_mapper, r) + + def clear(dest, dest_mapper, synchronize_pairs): for l, r in synchronize_pairs: - if r.primary_key: + if r.primary_key and \ + dest_mapper._get_state_attr_by_column( + dest, dest.dict, r) not in orm_util._none_set: + raise AssertionError( - "Dependency rule tried to blank-out primary key " - "column '%s' on instance '%s'" % - (r, mapperutil.state_str(dest)) - ) + "Dependency rule tried to blank-out primary key " + "column '%s' on instance '%s'" % + (r, orm_util.state_str(dest)) + ) try: - dest_mapper._set_state_attr_by_column(dest, r, None) + dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None) except exc.UnmappedColumnError: _raise_col_to_prop(True, None, l, dest_mapper, r) + def update(source, source_mapper, dest, old_prefix, synchronize_pairs): for l, r in synchronize_pairs: try: - oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l) - value = source_mapper._get_state_attr_by_column(source, l) + oldvalue = source_mapper._get_committed_attr_by_column( + source.obj(), l) + value = source_mapper._get_state_attr_by_column( + source, source.dict, l, passive=attributes.PASSIVE_OFF) except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) dest[r.key] = value dest[old_prefix + r.key] = oldvalue + def populate_dict(source, source_mapper, dict_, synchronize_pairs): for l, r in synchronize_pairs: try: - value = source_mapper._get_state_attr_by_column(source, l) + value = source_mapper._get_state_attr_by_column( + source, source.dict, l, passive=attributes.PASSIVE_OFF) except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) dict_[r.key] = value + def source_modified(uowcommit, source, source_mapper, synchronize_pairs): - """return true if the source object has changes from an old to a + """return true if the source object has changes from an old to a new value on the given synchronize pairs - + """ for l, r in synchronize_pairs: try: - prop = source_mapper._get_col_to_prop(l) + prop = source_mapper._columntoproperty[l] except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) - history = uowcommit.get_attribute_history(source, prop.key, passive=True) - if len(history.deleted): + history = uowcommit.get_attribute_history( + source, prop.key, attributes.PASSIVE_NO_INITIALIZE) + if bool(history.deleted): return True else: return False -def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column): + +def _raise_col_to_prop(isdest, source_mapper, source_column, + dest_mapper, dest_column): if isdest: raise exc.UnmappedColumnError( - "Can't execute sync rule for destination column '%s'; " - "mapper '%s' does not map this column. 
Try using an explicit" - " `foreign_keys` collection which does not include this column " - "(or use a viewonly=True relation)." % (dest_column, source_mapper) - ) + "Can't execute sync rule for " + "destination column '%s'; mapper '%s' does not map " + "this column. Try using an explicit `foreign_keys` " + "collection which does not include this column (or use " + "a viewonly=True relation)." % (dest_column, dest_mapper)) else: raise exc.UnmappedColumnError( - "Can't execute sync rule for source column '%s'; mapper '%s' " - "does not map this column. Try using an explicit `foreign_keys`" - " collection which does not include destination column '%s' (or " - "use a viewonly=True relation)." % - (source_column, source_mapper, dest_column) - ) + "Can't execute sync rule for " + "source column '%s'; mapper '%s' does not map this " + "column. Try using an explicit `foreign_keys` " + "collection which does not include destination column " + "'%s' (or use a viewonly=True relation)." % + (source_column, source_mapper, dest_column)) diff --git a/sqlalchemy/orm/unitofwork.py b/sqlalchemy/orm/unitofwork.py index 30b0b61..3a39a30 100644 --- a/sqlalchemy/orm/unitofwork.py +++ b/sqlalchemy/orm/unitofwork.py @@ -1,781 +1,672 @@ # orm/unitofwork.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""The internals for the Unit Of Work system. +"""The internals for the unit of work system. -Includes hooks into the attributes package enabling the routing of -change events to Unit Of Work objects, as well as the flush() -mechanism which creates a dependency structure that executes change -operations. - -A Unit of Work is essentially a system of maintaining a graph of -in-memory objects and their modified state. Objects are maintained as -unique against their primary key identity using an *identity map* -pattern. The Unit of Work then maintains lists of objects that are -new, dirty, or deleted and provides the capability to flush all those -changes at once. +The session's flush() process passes objects to a contextual object +here, which assembles flush tasks based on mappers and their properties, +organizes them in order of dependency, and executes. """ -from sqlalchemy import util, log, topological -from sqlalchemy.orm import attributes, interfaces -from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.mapper import _state_mapper +from .. import util, event +from ..util import topological +from . import attributes, persistence, util as orm_util +from . import exc as orm_exc +import itertools -# Load lazily -object_session = None -_state_session = None -class UOWEventHandler(interfaces.AttributeExtension): - """An event handler added to all relationship attributes which handles - session cascade operations. +def track_cascade_events(descriptor, prop): + """Establish event listeners on object attributes which handle + cascade-on-set/append. 
+ """ - - active_history = False - - def __init__(self, key): - self.key = key + key = prop.key - def append(self, state, item, initiator): - # process "save_update" cascade rules for when an instance is appended to the list of another instance - sess = _state_session(state) + def append(state, item, initiator): + # process "save_update" cascade rules for when + # an instance is appended to the list of another instance + + if item is None: + return + + sess = state.session if sess: - prop = _state_mapper(state).get_property(self.key) - if prop.cascade.save_update and item not in sess: - sess.add(item) + if sess._warn_on_events: + sess._flush_warning("collection append") + + prop = state.manager.mapper._props[key] + item_state = attributes.instance_state(item) + if prop._cascade.save_update and \ + (prop.cascade_backrefs or key == initiator.key) and \ + not sess._contains_state(item_state): + sess._save_or_update_state(item_state) return item - - def remove(self, state, item, initiator): - sess = _state_session(state) - if sess: - prop = _state_mapper(state).get_property(self.key) - # expunge pending orphans - if prop.cascade.delete_orphan and \ - item in sess.new and \ - prop.mapper._is_orphan(attributes.instance_state(item)): - sess.expunge(item) - def set(self, state, newvalue, oldvalue, initiator): - # process "save_update" cascade rules for when an instance is attached to another instance + def remove(state, item, initiator): + if item is None: + return + + sess = state.session + if sess: + + prop = state.manager.mapper._props[key] + + if sess._warn_on_events: + sess._flush_warning( + "collection remove" + if prop.uselist + else "related attribute delete") + + # expunge pending orphans + item_state = attributes.instance_state(item) + if prop._cascade.delete_orphan and \ + item_state in sess._new and \ + prop.mapper._is_orphan(item_state): + sess.expunge(item) + + def set_(state, newvalue, oldvalue, initiator): + # process "save_update" cascade rules for when an instance + # is attached to another instance if oldvalue is newvalue: return newvalue - sess = _state_session(state) + + sess = state.session if sess: - prop = _state_mapper(state).get_property(self.key) - if newvalue is not None and prop.cascade.save_update and newvalue not in sess: - sess.add(newvalue) - if prop.cascade.delete_orphan and oldvalue in sess.new and \ - prop.mapper._is_orphan(attributes.instance_state(oldvalue)): - sess.expunge(oldvalue) + + if sess._warn_on_events: + sess._flush_warning("related attribute set") + + prop = state.manager.mapper._props[key] + if newvalue is not None: + newvalue_state = attributes.instance_state(newvalue) + if prop._cascade.save_update and \ + (prop.cascade_backrefs or key == initiator.key) and \ + not sess._contains_state(newvalue_state): + sess._save_or_update_state(newvalue_state) + + if oldvalue is not None and \ + oldvalue is not attributes.NEVER_SET and \ + oldvalue is not attributes.PASSIVE_NO_RESULT and \ + prop._cascade.delete_orphan: + # possible to reach here with attributes.NEVER_SET ? 
+ oldvalue_state = attributes.instance_state(oldvalue) + + if oldvalue_state in sess._new and \ + prop.mapper._is_orphan(oldvalue_state): + sess.expunge(oldvalue) return newvalue + event.listen(descriptor, 'append', append, raw=True, retval=True) + event.listen(descriptor, 'remove', remove, raw=True, retval=True) + event.listen(descriptor, 'set', set_, raw=True, retval=True) + class UOWTransaction(object): - """Handles the details of organizing and executing transaction - tasks during a UnitOfWork object's flush() operation. - - The central operation is to form a graph of nodes represented by the - ``UOWTask`` class, which is then traversed by a ``UOWExecutor`` object - that issues SQL and instance-synchronizing operations via the related - packages. - """ - def __init__(self, session): self.session = session - self.mapper_flush_opts = session._mapper_flush_opts - # stores tuples of mapper/dependent mapper pairs, - # representing a partial ordering fed into topological sort + # dictionary used by external actors to + # store arbitrary state information. + self.attributes = {} + + # dictionary of mappers to sets of + # DependencyProcessors, which are also + # set to be part of the sorted flush actions, + # which have that mapper as a parent. + self.deps = util.defaultdict(set) + + # dictionary of mappers to sets of InstanceState + # items pending for flush which have that mapper + # as a parent. + self.mappers = util.defaultdict(set) + + # a dictionary of Preprocess objects, which gather + # additional states impacted by the flush + # and determine if a flush action is needed + self.presort_actions = {} + + # dictionary of PostSortRec objects, each + # one issues work during the flush within + # a certain ordering. + self.postsort_actions = {} + + # a set of 2-tuples, each containing two + # PostSortRec objects where the second + # is dependent on the first being executed + # first self.dependencies = set() - # dictionary of mappers to UOWTasks - self.tasks = {} + # dictionary of InstanceState-> (isdelete, listonly) + # tuples, indicating if this state is to be deleted + # or insert/updated, or just refreshed + self.states = {} + + # tracks InstanceStates which will be receiving + # a "post update" call. Keys are mappers, + # values are a set of states and a set of the + # columns which should be included in the update. + self.post_update_states = util.defaultdict(lambda: (set(), set())) + + @property + def has_work(self): + return bool(self.states) + + def was_already_deleted(self, state): + """return true if the given state is expired and was deleted + previously. 
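The three listeners that track_cascade_events() attaches above are what make the save-update cascade happen at event time rather than at flush: appending a transient object to a collection on a session-resident parent immediately pulls it into that Session. A runnable sketch against a hypothetical one-to-many with the default cascade:

    from sqlalchemy import Column, ForeignKey, Integer, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship, sessionmaker

    Base = declarative_base()

    class Order(Base):
        __tablename__ = 'orders'
        id = Column(Integer, primary_key=True)
        items = relationship("Item")   # default cascade includes save-update

    class Item(Base):
        __tablename__ = 'items'
        id = Column(Integer, primary_key=True)
        order_id = Column(Integer, ForeignKey('orders.id'))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    order = Order()
    session.add(order)
    item = Item()
    order.items.append(item)   # fires the 'append' listener above
    assert item in session     # save-update cascaded it into the Session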
+ """ + if state.expired: + try: + state._load_expired(state, attributes.PASSIVE_OFF) + except orm_exc.ObjectDeletedError: + self.session._remove_newly_deleted([state]) + return True + return False + + def is_deleted(self, state): + """return true if the given state is marked as deleted + within this uowtransaction.""" + + return state in self.states and self.states[state][0] + + def memo(self, key, callable_): + if key in self.attributes: + return self.attributes[key] + else: + self.attributes[key] = ret = callable_() + return ret + + def remove_state_actions(self, state): + """remove pending actions for a state from the uowtransaction.""" + + isdelete = self.states[state][0] + + self.states[state] = (isdelete, True) + + def get_attribute_history(self, state, key, + passive=attributes.PASSIVE_NO_INITIALIZE): + """facade to attributes.get_state_history(), including + caching of results.""" - # dictionary used by external actors to store arbitrary state - # information. - self.attributes = {} - - self.processors = set() - - def get_attribute_history(self, state, key, passive=True): hashkey = ("history", state, key) # cache the objects, not the states; the strong reference here # prevents newly loaded objects from being dereferenced during the # flush process + if hashkey in self.attributes: - (history, cached_passive) = self.attributes[hashkey] - # if the cached lookup was "passive" and now we want non-passive, do a non-passive - # lookup and re-cache - if cached_passive and not passive: - history = attributes.get_state_history(state, key, passive=False) - self.attributes[hashkey] = (history, passive) - else: - history = attributes.get_state_history(state, key, passive=passive) - self.attributes[hashkey] = (history, passive) + history, state_history, cached_passive = self.attributes[hashkey] + # if the cached lookup was "passive" and now + # we want non-passive, do a non-passive lookup and re-cache - if not history or not state.get_impl(key).uses_objects: - return history + if not cached_passive & attributes.SQL_OK \ + and passive & attributes.SQL_OK: + impl = state.manager[key].impl + history = impl.get_history(state, state.dict, + attributes.PASSIVE_OFF | + attributes.LOAD_AGAINST_COMMITTED) + if history and impl.uses_objects: + state_history = history.as_state() + else: + state_history = history + self.attributes[hashkey] = (history, state_history, passive) else: - return history.as_state() + impl = state.manager[key].impl + # TODO: store the history as (state, object) tuples + # so we don't have to keep converting here + history = impl.get_history(state, state.dict, passive | + attributes.LOAD_AGAINST_COMMITTED) + if history and impl.uses_objects: + state_history = history.as_state() + else: + state_history = history + self.attributes[hashkey] = (history, state_history, + passive) - def register_object(self, state, isdelete=False, - listonly=False, postupdate=False, post_update_cols=None): - - # if object is not in the overall session, do nothing + return state_history + + def has_dep(self, processor): + return (processor, True) in self.presort_actions + + def register_preprocessor(self, processor, fromparent): + key = (processor, fromparent) + if key not in self.presort_actions: + self.presort_actions[key] = Preprocess(processor, fromparent) + + def register_object(self, state, isdelete=False, + listonly=False, cancel_delete=False, + operation=None, prop=None): if not self.session._contains_state(state): - return + # this condition is normal when objects are registered + # as part of a 
relationship cascade operation. it should + # not occur for the top-level register from Session.flush(). + if not state.deleted and operation is not None: + util.warn("Object of type %s not in session, %s operation " + "along '%s' will not proceed" % + (orm_util.state_class_str(state), operation, prop)) + return False - mapper = _state_mapper(state) + if state not in self.states: + mapper = state.manager.mapper - task = self.get_task_by_mapper(mapper) - if postupdate: - task.append_postupdate(state, post_update_cols) + if mapper not in self.mappers: + self._per_mapper_flush_actions(mapper) + + self.mappers[mapper].add(state) + self.states[state] = (isdelete, listonly) else: - task.append(state, listonly=listonly, isdelete=isdelete) + if not listonly and (isdelete or cancel_delete): + self.states[state] = (isdelete, False) + return True - # ensure the mapper for this object has had its - # DependencyProcessors added. - if mapper not in self.processors: - mapper._register_processors(self) - self.processors.add(mapper) + def issue_post_update(self, state, post_update_cols): + mapper = state.manager.mapper.base_mapper + states, cols = self.post_update_states[mapper] + states.add(state) + cols.update(post_update_cols) - if mapper.base_mapper not in self.processors: - mapper.base_mapper._register_processors(self) - self.processors.add(mapper.base_mapper) - - def set_row_switch(self, state): - """mark a deleted object as a 'row switch'. + def _per_mapper_flush_actions(self, mapper): + saves = SaveUpdateAll(self, mapper.base_mapper) + deletes = DeleteAll(self, mapper.base_mapper) + self.dependencies.add((saves, deletes)) - this indicates that an INSERT statement elsewhere corresponds to this DELETE; - the INSERT is converted to an UPDATE and the DELETE does not occur. - - """ - mapper = _state_mapper(state) - task = self.get_task_by_mapper(mapper) - taskelement = task._objects[state] - taskelement.isdelete = "rowswitch" - - def is_deleted(self, state): - """return true if the given state is marked as deleted within this UOWTransaction.""" + for dep in mapper._dependency_processors: + dep.per_property_preprocessors(self) - mapper = _state_mapper(state) - task = self.get_task_by_mapper(mapper) - return task.is_deleted(state) + for prop in mapper.relationships: + if prop.viewonly: + continue + dep = prop._dependency_processor + dep.per_property_preprocessors(self) - def get_task_by_mapper(self, mapper, dontcreate=False): - """return UOWTask element corresponding to the given mapper. + @util.memoized_property + def _mapper_for_dep(self): + """return a dynamic mapping of (Mapper, DependencyProcessor) to + True or False, indicating if the DependencyProcessor operates + on objects of that Mapper. - Will create a new UOWTask, including a UOWTask corresponding to the - "base" inherited mapper, if needed, unless the dontcreate flag is True. - - """ - try: - return self.tasks[mapper] - except KeyError: - if dontcreate: - return None - - base_mapper = mapper.base_mapper - if base_mapper in self.tasks: - base_task = self.tasks[base_mapper] - else: - self.tasks[base_mapper] = base_task = UOWTask(self, base_mapper) - base_mapper._register_dependencies(self) - - if mapper not in self.tasks: - self.tasks[mapper] = task = UOWTask(self, mapper, base_task=base_task) - mapper._register_dependencies(self) - else: - task = self.tasks[mapper] - - return task - - def register_dependency(self, mapper, dependency): - """register a dependency between two mappers. 
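get_attribute_history() earlier in this hunk is a flush-scoped, memoizing front end over the same History triples exposed by the public attributes.get_history() function; caching the objects rather than only the states also keeps freshly loaded items from being garbage-collected mid-flush. The shape of the data, with a hypothetical mapped User:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.orm.attributes import get_history

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'users'
        id = Column(Integer, primary_key=True)
        name = Column(String)

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    u = User(name='ed')
    session.add(u)
    session.flush()
    u.name = 'jack'
    # History(added=['jack'], unchanged=(), deleted=['ed'])
    print(get_history(u, 'name'))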
- - Called by ``mapper.PropertyLoader`` to register the objects - handled by one mapper being dependent on the objects handled - by another. + The result is stored in the dictionary persistently once + calculated. """ - # correct for primary mapper - # also convert to the "base mapper", the parentmost task at the top of an inheritance chain - # dependency sorting is done via non-inheriting mappers only, dependencies between mappers - # in the same inheritance chain is done at the per-object level - mapper = mapper.primary_mapper().base_mapper - dependency = dependency.primary_mapper().base_mapper + return util.PopulateDict( + lambda tup: tup[0]._props.get(tup[1].key) is tup[1].prop + ) - self.dependencies.add((mapper, dependency)) + def filter_states_for_dep(self, dep, states): + """Filter the given list of InstanceStates to those relevant to the + given DependencyProcessor. - def register_processor(self, mapper, processor, mapperfrom): - """register a dependency processor, corresponding to - operations which occur between two mappers. - """ - # correct for primary mapper - mapper = mapper.primary_mapper() - mapperfrom = mapperfrom.primary_mapper() + mapper_for_dep = self._mapper_for_dep + return [s for s in states if mapper_for_dep[(s.manager.mapper, dep)]] - task = self.get_task_by_mapper(mapper) - targettask = self.get_task_by_mapper(mapperfrom) - up = UOWDependencyProcessor(processor, targettask) - task.dependencies.add(up) + def states_for_mapper_hierarchy(self, mapper, isdelete, listonly): + checktup = (isdelete, listonly) + for mapper in mapper.base_mapper.self_and_descendants: + for state in self.mappers[mapper]: + if self.states[state] == checktup: + yield state - def execute(self): - """Execute this UOWTransaction. + def _generate_actions(self): + """Generate the full, unsorted collection of PostSortRecs as + well as dependency pairs for this UOWTransaction. - This will organize all collected UOWTasks into a dependency-sorted - list which is then traversed using the traversal scheme - encoded in the UOWExecutor class. Operations to mappers and dependency - processors are fired off in order to issue SQL to the database and - synchronize instance attributes with database values and related - foreign key values.""" - - # pre-execute dependency processors. this process may - # result in new tasks, objects and/or dependency processors being added, - # particularly with 'delete-orphan' cascade rules. - # keep running through the full list of tasks until all - # objects have been processed. + """ + # execute presort_actions, until all states + # have been processed. a presort_action might + # add new states to the uow. while True: ret = False - for task in self.tasks.values(): - for up in list(task.dependencies): - if up.preexecute(self): - ret = True + for action in list(self.presort_actions.values()): + if action.execute(self): + ret = True if not ret: break - tasks = self._sort_dependencies() - if self._should_log_info(): - self.logger.info("Task dump:\n%s", self._dump(tasks)) - UOWExecutor().execute(self, tasks) - self.logger.info("Execute Complete") + # see if the graph of mapper dependencies has cycles. 
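The cycle check that follows leans on the internal sqlalchemy.util.topological helpers; a dependency is just a 2-tuple meaning "left must execute before right". A standalone illustration (private API, shown only to clarify the data shapes involved):

    from sqlalchemy.util import topological

    # "x before y" pairs, the same shape as uow.dependencies
    deps = set([('saves', 'fk_process'), ('fk_process', 'deletes')])
    nodes = ['deletes', 'fk_process', 'saves']
    print(list(topological.sort(deps, nodes)))
    # ['saves', 'fk_process', 'deletes']

    # mutually dependent nodes come back from find_cycles(),
    # which is what triggers the per-state breakup below
    print(topological.find_cycles(set([('a', 'b'), ('b', 'a')]), ['a', 'b']))
    # set(['a', 'b'])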
+ self.cycles = cycles = topological.find_cycles( + self.dependencies, + list(self.postsort_actions.values())) - def _dump(self, tasks): - from uowdumper import UOWDumper - return UOWDumper.dump(tasks) + if cycles: + # if yes, break the per-mapper actions into + # per-state actions + convert = dict( + (rec, set(rec.per_state_flush_actions(self))) + for rec in cycles + ) - @property - def elements(self): - """Iterate UOWTaskElements.""" - - for task in self.tasks.itervalues(): - for elem in task.elements: - yield elem + # rewrite the existing dependencies to point to + # the per-state actions for those per-mapper actions + # that were broken up. + for edge in list(self.dependencies): + if None in edge or \ + edge[0].disabled or edge[1].disabled or \ + cycles.issuperset(edge): + self.dependencies.remove(edge) + elif edge[0] in cycles: + self.dependencies.remove(edge) + for dep in convert[edge[0]]: + self.dependencies.add((dep, edge[1])) + elif edge[1] in cycles: + self.dependencies.remove(edge) + for dep in convert[edge[1]]: + self.dependencies.add((edge[0], dep)) + + return set([a for a in self.postsort_actions.values() + if not a.disabled + ] + ).difference(cycles) + + def execute(self): + postsort_actions = self._generate_actions() + + # sort = topological.sort(self.dependencies, postsort_actions) + # print "--------------" + # print "\ndependencies:", self.dependencies + # print "\ncycles:", self.cycles + # print "\nsort:", list(sort) + # print "\nCOUNT OF POSTSORT ACTIONS", len(postsort_actions) + + # execute + if self.cycles: + for set_ in topological.sort_as_subsets( + self.dependencies, + postsort_actions): + while set_: + n = set_.pop() + n.execute_aggregate(self, set_) + else: + for rec in topological.sort( + self.dependencies, + postsort_actions): + rec.execute(self) def finalize_flush_changes(self): - """mark processed objects as clean / deleted after a successful flush(). + """mark processed objects as clean / deleted after a successful + flush(). this method is called within the flush() method after the execute() method has succeeded and the transaction has been committed. 
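In execute() above, the cycle-free path walks a flat topological.sort(), while the cyclic path uses topological.sort_as_subsets() and drains each subset through execute_aggregate(), letting records of the same class batch themselves into one call. The subsets are dependency "levels" whose members have no ordering among themselves (same private-API caveat as before):

    from sqlalchemy.util import topological

    deps = set([('a', 'c'), ('b', 'c')])
    for batch in topological.sort_as_subsets(deps, ['a', 'b', 'c']):
        print(batch)
    # set(['a', 'b'])  -- unordered relative to each other, batchable
    # set(['c'])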
+ """ + if not self.states: + return - for elem in self.elements: - if elem.isdelete: - self.session._remove_newly_deleted(elem.state) - elif not elem.listonly: - self.session._register_newly_persistent(elem.state) + states = set(self.states) + isdel = set( + s for (s, (isdelete, listonly)) in self.states.items() + if isdelete + ) + other = states.difference(isdel) + if isdel: + self.session._remove_newly_deleted(isdel) + if other: + self.session._register_newly_persistent(other) - def _sort_dependencies(self): - nodes = topological.sort_with_cycles(self.dependencies, - [t.mapper for t in self.tasks.itervalues() if t.base_task is t] + +class IterateMappersMixin(object): + def _mappers(self, uow): + if self.fromparent: + return iter( + m for m in + self.dependency_processor.parent.self_and_descendants + if uow._mapper_for_dep[(m, self.dependency_processor)] + ) + else: + return self.dependency_processor.mapper.self_and_descendants + + +class Preprocess(IterateMappersMixin): + def __init__(self, dependency_processor, fromparent): + self.dependency_processor = dependency_processor + self.fromparent = fromparent + self.processed = set() + self.setup_flush_actions = False + + def execute(self, uow): + delete_states = set() + save_states = set() + + for mapper in self._mappers(uow): + for state in uow.mappers[mapper].difference(self.processed): + (isdelete, listonly) = uow.states[state] + if not listonly: + if isdelete: + delete_states.add(state) + else: + save_states.add(state) + + if delete_states: + self.dependency_processor.presort_deletes(uow, delete_states) + self.processed.update(delete_states) + if save_states: + self.dependency_processor.presort_saves(uow, save_states) + self.processed.update(save_states) + + if (delete_states or save_states): + if not self.setup_flush_actions and ( + self.dependency_processor. + prop_has_changes(uow, delete_states, True) or + self.dependency_processor. + prop_has_changes(uow, save_states, False) + ): + self.dependency_processor.per_property_flush_actions(uow) + self.setup_flush_actions = True + return True + else: + return False + + +class PostSortRec(object): + disabled = False + + def __new__(cls, uow, *args): + key = (cls, ) + args + if key in uow.postsort_actions: + return uow.postsort_actions[key] + else: + uow.postsort_actions[key] = \ + ret = \ + object.__new__(cls) + return ret + + def execute_aggregate(self, uow, recs): + self.execute(uow) + + def __repr__(self): + return "%s(%s)" % ( + self.__class__.__name__, + ",".join(str(x) for x in self.__dict__.values()) ) - ret = [] - for item, cycles in nodes: - task = self.get_task_by_mapper(item) - if cycles: - for t in task._sort_circular_dependencies( - self, - [self.get_task_by_mapper(i) for i in cycles] - ): - ret.append(t) - else: - ret.append(task) - return ret +class ProcessAll(IterateMappersMixin, PostSortRec): + def __init__(self, uow, dependency_processor, delete, fromparent): + self.dependency_processor = dependency_processor + self.delete = delete + self.fromparent = fromparent + uow.deps[dependency_processor.parent.base_mapper].\ + add(dependency_processor) -log.class_logger(UOWTransaction) - -class UOWTask(object): - """A collection of mapped states corresponding to a particular mapper.""" - - def __init__(self, uowtransaction, mapper, base_task=None): - self.uowtransaction = uowtransaction - - # base_task is the UOWTask which represents the "base mapper" - # in our mapper's inheritance chain. if the mapper does not - # inherit from any other mapper, the base_task is self. 
- # the _inheriting_tasks dictionary is a dictionary present only - # on the "base_task"-holding UOWTask, which maps all mappers within - # an inheritance hierarchy to their corresponding UOWTask instances. - if base_task is None: - self.base_task = self - self._inheriting_tasks = {mapper:self} + def execute(self, uow): + states = self._elements(uow) + if self.delete: + self.dependency_processor.process_deletes(uow, states) else: - self.base_task = base_task - base_task._inheriting_tasks[mapper] = self + self.dependency_processor.process_saves(uow, states) - # the Mapper which this UOWTask corresponds to + def per_state_flush_actions(self, uow): + # this is handled by SaveUpdateAll and DeleteAll, + # since a ProcessAll should unconditionally be pulled + # into per-state if either the parent/child mappers + # are part of a cycle + return iter([]) + + def __repr__(self): + return "%s(%s, delete=%s)" % ( + self.__class__.__name__, + self.dependency_processor, + self.delete + ) + + def _elements(self, uow): + for mapper in self._mappers(uow): + for state in uow.mappers[mapper]: + (isdelete, listonly) = uow.states[state] + if isdelete == self.delete and not listonly: + yield state + + +class IssuePostUpdate(PostSortRec): + def __init__(self, uow, mapper, isdelete): + self.mapper = mapper + self.isdelete = isdelete + + def execute(self, uow): + states, cols = uow.post_update_states[self.mapper] + states = [s for s in states if uow.states[s][0] == self.isdelete] + + persistence.post_update(self.mapper, states, uow, cols) + + +class SaveUpdateAll(PostSortRec): + def __init__(self, uow, mapper): + self.mapper = mapper + assert mapper is mapper.base_mapper + + def execute(self, uow): + persistence.save_obj(self.mapper, + uow.states_for_mapper_hierarchy( + self.mapper, False, False), + uow + ) + + def per_state_flush_actions(self, uow): + states = list(uow.states_for_mapper_hierarchy( + self.mapper, False, False)) + base_mapper = self.mapper.base_mapper + delete_all = DeleteAll(uow, base_mapper) + for state in states: + # keep saves before deletes - + # this ensures 'row switch' operations work + action = SaveUpdateState(uow, state, base_mapper) + uow.dependencies.add((action, delete_all)) + yield action + + for dep in uow.deps[self.mapper]: + states_for_prop = uow.filter_states_for_dep(dep, states) + dep.per_state_flush_actions(uow, states_for_prop, False) + + +class DeleteAll(PostSortRec): + def __init__(self, uow, mapper): + self.mapper = mapper + assert mapper is mapper.base_mapper + + def execute(self, uow): + persistence.delete_obj(self.mapper, + uow.states_for_mapper_hierarchy( + self.mapper, True, False), + uow + ) + + def per_state_flush_actions(self, uow): + states = list(uow.states_for_mapper_hierarchy( + self.mapper, True, False)) + base_mapper = self.mapper.base_mapper + save_all = SaveUpdateAll(uow, base_mapper) + for state in states: + # keep saves before deletes - + # this ensures 'row switch' operations work + action = DeleteState(uow, state, base_mapper) + uow.dependencies.add((save_all, action)) + yield action + + for dep in uow.deps[self.mapper]: + states_for_prop = uow.filter_states_for_dep(dep, states) + dep.per_state_flush_actions(uow, states_for_prop, True) + + +class ProcessState(PostSortRec): + def __init__(self, uow, dependency_processor, delete, state): + self.dependency_processor = dependency_processor + self.delete = delete + self.state = state + + def execute_aggregate(self, uow, recs): + cls_ = self.__class__ + dependency_processor = self.dependency_processor + delete 
= self.delete + our_recs = [r for r in recs + if r.__class__ is cls_ and + r.dependency_processor is dependency_processor and + r.delete is delete] + recs.difference_update(our_recs) + states = [self.state] + [r.state for r in our_recs] + if delete: + dependency_processor.process_deletes(uow, states) + else: + dependency_processor.process_saves(uow, states) + + def __repr__(self): + return "%s(%s, %s, delete=%s)" % ( + self.__class__.__name__, + self.dependency_processor, + orm_util.state_str(self.state), + self.delete + ) + + +class SaveUpdateState(PostSortRec): + def __init__(self, uow, state, mapper): + self.state = state self.mapper = mapper - # mapping of InstanceState -> UOWTaskElement - self._objects = {} - - self.dependent_tasks = [] - self.dependencies = set() - self.cyclical_dependencies = set() - - @util.memoized_property - def inheriting_mappers(self): - return list(self.mapper.polymorphic_iterator()) - - @property - def polymorphic_tasks(self): - """Return an iterator of UOWTask objects corresponding to the - inheritance sequence of this UOWTask's mapper. - - e.g. if mapper B and mapper C inherit from mapper A, and - mapper D inherits from B: - - mapperA -> mapperB -> mapperD - -> mapperC - - the inheritance sequence starting at mapper A is a depth-first - traversal: - - [mapperA, mapperB, mapperD, mapperC] - - this method will therefore return - - [UOWTask(mapperA), UOWTask(mapperB), UOWTask(mapperD), - UOWTask(mapperC)] - - The concept of "polymporphic iteration" is adapted into - several property-based iterators which return object - instances, UOWTaskElements and UOWDependencyProcessors in an - order corresponding to this sequence of parent UOWTasks. This - is used to issue operations related to inheritance-chains of - mappers in the proper order based on dependencies between - those mappers. - - """ - for mapper in self.inheriting_mappers: - t = self.base_task._inheriting_tasks.get(mapper, None) - if t is not None: - yield t - - def is_empty(self): - """return True if this UOWTask is 'empty', meaning it has no child items. - - used only for debugging output. - """ - - return not self._objects and not self.dependencies - - def append(self, state, listonly=False, isdelete=False): - if state not in self._objects: - self._objects[state] = rec = UOWTaskElement(state) - else: - rec = self._objects[state] - - rec.update(listonly, isdelete) - - def append_postupdate(self, state, post_update_cols): - """issue a 'post update' UPDATE statement via this object's mapper immediately. - - this operation is used only with relationships that specify the `post_update=True` - flag. 
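IssuePostUpdate above (together with issue_post_update() on the transaction) is the machinery behind relationship(post_update=True), which resolves a row-level dependency cycle by emitting a second UPDATE after the INSERTs rather than trying to order them. The classic shape, sketched with hypothetical mutually referencing Widget/Entry classes:

    from sqlalchemy import Column, ForeignKey, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class Widget(Base):
        __tablename__ = 'widget'
        id = Column(Integer, primary_key=True)
        favorite_entry_id = Column(Integer, ForeignKey('entry.id'))
        # without post_update=True, flushing a Widget together with
        # its favorite Entry raises a circular-dependency error
        favorite_entry = relationship('Entry',
                                      foreign_keys=favorite_entry_id,
                                      post_update=True)

    class Entry(Base):
        __tablename__ = 'entry'
        id = Column(Integer, primary_key=True)
        widget_id = Column(Integer, ForeignKey('widget.id'))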
- """ - - # postupdates are UPDATED immeditely (for now) - # convert post_update_cols list to a Set so that __hash__() is used to compare columns - # instead of __eq__() - self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=set(post_update_cols)) - - def __contains__(self, state): - """return True if the given object is contained within this UOWTask or inheriting tasks.""" - - for task in self.polymorphic_tasks: - if state in task._objects: - return True - else: - return False - - def is_deleted(self, state): - """return True if the given object is marked as to be deleted within this UOWTask.""" - - try: - return self._objects[state].isdelete - except KeyError: - return False - - def _polymorphic_collection(fn): - """return a property that will adapt the collection returned by the - given callable into a polymorphic traversal.""" - - @property - def collection(self): - for task in self.polymorphic_tasks: - for rec in fn(task): - yield rec - return collection - - def _polymorphic_collection_filtered(fn): - - def collection(self, mappers): - for task in self.polymorphic_tasks: - if task.mapper in mappers: - for rec in fn(task): - yield rec - return collection - - @property - def elements(self): - return self._objects.values() - - @_polymorphic_collection - def polymorphic_elements(self): - return self.elements - - @_polymorphic_collection_filtered - def filter_polymorphic_elements(self): - return self.elements - - @property - def polymorphic_tosave_elements(self): - return [rec for rec in self.polymorphic_elements if not rec.isdelete] - - @property - def polymorphic_todelete_elements(self): - return [rec for rec in self.polymorphic_elements if rec.isdelete] - - @property - def polymorphic_tosave_objects(self): - return [ - rec.state for rec in self.polymorphic_elements - if rec.state is not None and not rec.listonly and rec.isdelete is False - ] - - @property - def polymorphic_todelete_objects(self): - return [ - rec.state for rec in self.polymorphic_elements - if rec.state is not None and not rec.listonly and rec.isdelete is True - ] - - @_polymorphic_collection - def polymorphic_dependencies(self): - return self.dependencies - - @_polymorphic_collection - def polymorphic_cyclical_dependencies(self): - return self.cyclical_dependencies - - def _sort_circular_dependencies(self, trans, cycles): - """Topologically sort individual entities with row-level dependencies. - - Builds a modified UOWTask structure, and is invoked when the - per-mapper topological structure is found to have cycles. 
- - """ - - dependencies = {} - def set_processor_for_state(state, depprocessor, target_state, isdelete): - if state not in dependencies: - dependencies[state] = {} - tasks = dependencies[state] - if depprocessor not in tasks: - tasks[depprocessor] = UOWDependencyProcessor( - depprocessor.processor, - UOWTask(self.uowtransaction, depprocessor.targettask.mapper) - ) - tasks[depprocessor].targettask.append(target_state, isdelete=isdelete) - - cycles = set(cycles) - def dependency_in_cycles(dep): - proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper, True) - targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper, True) - return targettask in cycles and (proctask is not None and proctask in cycles) - - deps_by_targettask = {} - extradeplist = [] - for task in cycles: - for dep in task.polymorphic_dependencies: - if not dependency_in_cycles(dep): - extradeplist.append(dep) - for t in dep.targettask.polymorphic_tasks: - l = deps_by_targettask.setdefault(t, []) - l.append(dep) - - object_to_original_task = {} - tuples = [] - - for task in cycles: - for subtask in task.polymorphic_tasks: - for taskelement in subtask.elements: - state = taskelement.state - object_to_original_task[state] = subtask - if subtask not in deps_by_targettask: - continue - for dep in deps_by_targettask[subtask]: - if not dep.processor.has_dependencies or not dependency_in_cycles(dep): - continue - (processor, targettask) = (dep.processor, dep.targettask) - isdelete = taskelement.isdelete - - # list of dependent objects from this object - (added, unchanged, deleted) = dep.get_object_dependencies(state, trans, passive=True) - if not added and not unchanged and not deleted: - continue - - # the task corresponding to saving/deleting of those dependent objects - childtask = trans.get_task_by_mapper(processor.mapper) - - childlist = added + unchanged + deleted - - for o in childlist: - if o is None: - continue - - if o not in childtask: - childtask.append(o, listonly=True) - object_to_original_task[o] = childtask - - whosdep = dep.whose_dependent_on_who(state, o) - if whosdep is not None: - tuples.append(whosdep) - - if whosdep[0] is state: - set_processor_for_state(whosdep[0], dep, whosdep[0], isdelete=isdelete) - else: - set_processor_for_state(whosdep[0], dep, whosdep[1], isdelete=isdelete) - else: - # TODO: no test coverage here - set_processor_for_state(state, dep, state, isdelete=isdelete) - - t = UOWTask(self.uowtransaction, self.mapper) - t.dependencies.update(extradeplist) - - used_tasks = set() - - # rationale for "tree" sort as opposed to a straight - # dependency - keep non-dependent objects - # grouped together, so that insert ordering as determined - # by session.add() is maintained. 
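A behavior that both the old cyclical sort and the new per-state actions take care to preserve is the "row switch" (see the "keep saves before deletes" comments in SaveUpdateAll and DeleteAll above): when one flush holds a DELETE and an INSERT for the same primary key, ordering the save ahead of the delete lets the unit of work detect the match and collapse the pair into a single UPDATE. A sketch with a hypothetical Setting class; echo=True shows the single statement:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()

    class Setting(Base):
        __tablename__ = 'settings'
        key = Column(Integer, primary_key=True)
        value = Column(String)

    engine = create_engine('sqlite://', echo=True)
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    session.add(Setting(key=1, value='old'))
    session.commit()

    session.delete(session.query(Setting).get(1))
    session.add(Setting(key=1, value='new'))
    session.flush()   # echoes one UPDATE, not DELETE + INSERT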
- # An alternative might be to represent the "insert order" - # as part of the topological sort itself, which would - # eliminate the need for this step (but may make the original - # topological sort more expensive) - head = topological.sort_as_tree(tuples, object_to_original_task.iterkeys()) - if head is not None: - original_to_tasks = {} - stack = [(head, t)] - while stack: - ((state, cycles, children), parenttask) = stack.pop() - - originating_task = object_to_original_task[state] - used_tasks.add(originating_task) - - if (parenttask, originating_task) not in original_to_tasks: - task = UOWTask(self.uowtransaction, originating_task.mapper) - original_to_tasks[(parenttask, originating_task)] = task - parenttask.dependent_tasks.append(task) - else: - task = original_to_tasks[(parenttask, originating_task)] - - task.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete) - - if state in dependencies: - task.cyclical_dependencies.update(dependencies[state].itervalues()) - - stack += [(n, task) for n in children] - - ret = [t] - - # add tasks that were in the cycle, but didnt get assembled - # into the cyclical tree, to the start of the list - for t2 in cycles: - if t2 not in used_tasks and t2 is not self: - localtask = UOWTask(self.uowtransaction, t2.mapper) - for state in t2.elements: - localtask.append(state, t2.listonly, isdelete=t2._objects[state].isdelete) - for dep in t2.dependencies: - localtask.dependencies.add(dep) - ret.insert(0, localtask) - - return ret + def execute_aggregate(self, uow, recs): + cls_ = self.__class__ + mapper = self.mapper + our_recs = [r for r in recs + if r.__class__ is cls_ and + r.mapper is mapper] + recs.difference_update(our_recs) + persistence.save_obj(mapper, + [self.state] + + [r.state for r in our_recs], + uow) def __repr__(self): - return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper)) - -class UOWTaskElement(object): - """Corresponds to a single InstanceState to be saved, deleted, - or otherwise marked as having dependencies. A collection of - UOWTaskElements are held by a UOWTask. - - """ - def __init__(self, state): - self.state = state - self.listonly = True - self.isdelete = False - self.preprocessed = set() - - def update(self, listonly, isdelete): - if not listonly and self.listonly: - self.listonly = False - self.preprocessed.clear() - if isdelete and not self.isdelete: - self.isdelete = True - self.preprocessed.clear() - - def __repr__(self): - return "UOWTaskElement/%d: %s/%d %s" % ( - id(self), - self.state.class_.__name__, - id(self.state.obj()), - (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) + return "%s(%s)" % ( + self.__class__.__name__, + orm_util.state_str(self.state) ) -class UOWDependencyProcessor(object): - """In between the saving and deleting of objects, process - dependent data, such as filling in a foreign key on a child item - from a new primary key, or deleting association rows before a - delete. This object acts as a proxy to a DependencyProcessor. - """ - def __init__(self, processor, targettask): - self.processor = processor - self.targettask = targettask - prop = processor.prop - - # define a set of mappers which - # will filter the lists of entities - # this UOWDP processes. this allows - # MapperProperties to be overridden - # at least for concrete mappers. 
- self._mappers = set([ - m - for m in self.processor.parent.polymorphic_iterator() - if m._props[prop.key] is prop - ]).union(self.processor.mapper.polymorphic_iterator()) - +class DeleteState(PostSortRec): + def __init__(self, uow, state, mapper): + self.state = state + self.mapper = mapper + + def execute_aggregate(self, uow, recs): + cls_ = self.__class__ + mapper = self.mapper + our_recs = [r for r in recs + if r.__class__ is cls_ and + r.mapper is mapper] + recs.difference_update(our_recs) + states = [self.state] + [r.state for r in our_recs] + persistence.delete_obj(mapper, + [s for s in states if uow.states[s][0]], + uow) + def __repr__(self): - return "UOWDependencyProcessor(%s, %s)" % (str(self.processor), str(self.targettask)) - - def __eq__(self, other): - return other.processor is self.processor and other.targettask is self.targettask - - def __hash__(self): - return hash((self.processor, self.targettask)) - - def preexecute(self, trans): - """preprocess all objects contained within this ``UOWDependencyProcessor``s target task. - - This may locate additional objects which should be part of the - transaction, such as those affected deletes, orphans to be - deleted, etc. - - Once an object is preprocessed, its ``UOWTaskElement`` is marked as processed. If subsequent - changes occur to the ``UOWTaskElement``, its processed flag is reset, and will require processing - again. - - Return True if any objects were preprocessed, or False if no - objects were preprocessed. If True is returned, the parent ``UOWTransaction`` will - ultimately call ``preexecute()`` again on all processors until no new objects are processed. - """ - - def getobj(elem): - elem.preprocessed.add(self) - return elem.state - - ret = False - elements = [getobj(elem) for elem in - self.targettask.filter_polymorphic_elements(self._mappers) - if self not in elem.preprocessed and not elem.isdelete] - if elements: - ret = True - self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False) - - elements = [getobj(elem) for elem in - self.targettask.filter_polymorphic_elements(self._mappers) - if self not in elem.preprocessed and elem.isdelete] - if elements: - ret = True - self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True) - return ret - - def execute(self, trans, delete): - """process all objects contained within this ``UOWDependencyProcessor``s target task.""" - - - elements = [e for e in - self.targettask.filter_polymorphic_elements(self._mappers) - if bool(e.isdelete)==delete] - - self.processor.process_dependencies( - self.targettask, - [elem.state for elem in elements], - trans, - delete=delete) - - def get_object_dependencies(self, state, trans, passive): - return trans.get_attribute_history(state, self.processor.key, passive=passive) - - def whose_dependent_on_who(self, state1, state2): - """establish which object is operationally dependent amongst a parent/child - using the semantics stated by the dependency processor. - - This method is used to establish a partial ordering (set of dependency tuples) - when toplogically sorting on a per-instance basis. 
- - """ - return self.processor.whose_dependent_on_who(state1, state2) - -class UOWExecutor(object): - """Encapsulates the execution traversal of a UOWTransaction structure.""" - - def execute(self, trans, tasks, isdelete=None): - if isdelete is not True: - for task in tasks: - self.execute_save_steps(trans, task) - if isdelete is not False: - for task in reversed(tasks): - self.execute_delete_steps(trans, task) - - def save_objects(self, trans, task): - task.mapper._save_obj(task.polymorphic_tosave_objects, trans) - - def delete_objects(self, trans, task): - task.mapper._delete_obj(task.polymorphic_todelete_objects, trans) - - def execute_dependency(self, trans, dep, isdelete): - dep.execute(trans, isdelete) - - def execute_save_steps(self, trans, task): - self.save_objects(trans, task) - for dep in task.polymorphic_cyclical_dependencies: - self.execute_dependency(trans, dep, False) - for dep in task.polymorphic_cyclical_dependencies: - self.execute_dependency(trans, dep, True) - self.execute_cyclical_dependencies(trans, task, False) - self.execute_dependencies(trans, task) - - def execute_delete_steps(self, trans, task): - self.execute_cyclical_dependencies(trans, task, True) - self.delete_objects(trans, task) - - def execute_dependencies(self, trans, task): - polymorphic_dependencies = list(task.polymorphic_dependencies) - for dep in polymorphic_dependencies: - self.execute_dependency(trans, dep, False) - for dep in reversed(polymorphic_dependencies): - self.execute_dependency(trans, dep, True) - - def execute_cyclical_dependencies(self, trans, task, isdelete): - for t in task.dependent_tasks: - self.execute(trans, [t], isdelete) + return "%s(%s)" % ( + self.__class__.__name__, + orm_util.state_str(self.state) + ) diff --git a/sqlalchemy/orm/util.py b/sqlalchemy/orm/util.py index 63b9d56..fc0dba5 100644 --- a/sqlalchemy/orm/util.py +++ b/sqlalchemy/orm/util.py @@ -1,101 +1,159 @@ -# mapper/util.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# orm/util.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sqlalchemy.exceptions as sa_exc -from sqlalchemy import sql, util -from sqlalchemy.sql import expression, util as sql_util, operators -from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, \ - MapperProperty, AttributeExtension -from sqlalchemy.orm import attributes, exc -mapperlib = None +from .. import sql, util, event, exc as sa_exc, inspection +from ..sql import expression, util as sql_util, operators +from .interfaces import PropComparator, MapperProperty +from . 
import attributes +import re + +from .base import instance_str, state_str, state_class_str, attribute_str, \ + state_attribute_str, object_mapper, object_state, _none_set, _never_set +from .base import class_mapper, _class_to_mapper +from .base import InspectionAttr +from .path_registry import PathRegistry all_cascades = frozenset(("delete", "delete-orphan", "all", "merge", "expunge", "save-update", "refresh-expire", "none")) -_INSTRUMENTOR = ('mapper', 'instrumentor') -class CascadeOptions(object): +class CascadeOptions(frozenset): """Keeps track of the options sent to relationship().cascade""" - def __init__(self, arg=""): - if not arg: - values = set() - else: - values = set(c.strip() for c in arg.split(',')) + _add_w_all_cascades = all_cascades.difference([ + 'all', 'none', 'delete-orphan']) + _allowed_cascades = all_cascades + + __slots__ = ( + 'save_update', 'delete', 'refresh_expire', 'merge', + 'expunge', 'delete_orphan') + + def __new__(cls, value_list): + if isinstance(value_list, util.string_types) or value_list is None: + return cls.from_string(value_list) + values = set(value_list) + if values.difference(cls._allowed_cascades): + raise sa_exc.ArgumentError( + "Invalid cascade option(s): %s" % + ", ".join([repr(x) for x in + sorted(values.difference(cls._allowed_cascades))])) + + if "all" in values: + values.update(cls._add_w_all_cascades) + if "none" in values: + values.clear() + values.discard('all') + + self = frozenset.__new__(CascadeOptions, values) + self.save_update = 'save-update' in values + self.delete = 'delete' in values + self.refresh_expire = 'refresh-expire' in values + self.merge = 'merge' in values + self.expunge = 'expunge' in values self.delete_orphan = "delete-orphan" in values - self.delete = "delete" in values or "all" in values - self.save_update = "save-update" in values or "all" in values - self.merge = "merge" in values or "all" in values - self.expunge = "expunge" in values or "all" in values - self.refresh_expire = "refresh-expire" in values or "all" in values if self.delete_orphan and not self.delete: - util.warn("The 'delete-orphan' cascade option requires " - "'delete'. This will raise an error in 0.6.") - - for x in values: - if x not in all_cascades: - raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x) - - def __contains__(self, item): - return getattr(self, item.replace("-", "_"), False) + util.warn("The 'delete-orphan' cascade " + "option requires 'delete'.") + return self def __repr__(self): - return "CascadeOptions(%s)" % repr(",".join( - [x for x in ['delete', 'save_update', 'merge', 'expunge', - 'delete_orphan', 'refresh-expire'] - if getattr(self, x, False) is True])) + return "CascadeOptions(%r)" % ( + ",".join([x for x in sorted(self)]) + ) + + @classmethod + def from_string(cls, arg): + values = [ + c for c + in re.split(r'\s*,\s*', arg or "") + if c + ] + return cls(values) -class Validator(AttributeExtension): - """Runs a validation method on an attribute value to be set or appended. - - The Validator class is used by the :func:`~sqlalchemy.orm.validates` - decorator, and direct access is usually not needed. - +def _validator_events( + desc, key, validator, include_removes, include_backrefs): + """Runs a validation method on an attribute value to be set or + appended. """ - def __init__(self, key, validator): - """Construct a new Validator. 
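With the rewrite above, CascadeOptions is itself a frozenset, so plain membership tests and set algebra work directly, while the frequently checked flags remain attributes; "all" expands to every cascade except delete-orphan. A quick demonstration (CascadeOptions lives in the semi-private orm.util module):

    from sqlalchemy.orm.util import CascadeOptions

    opts = CascadeOptions("all, delete-orphan")
    assert opts.save_update and opts.delete_orphan
    assert "delete" in opts            # ordinary frozenset membership
    print(CascadeOptions.from_string("save-update, merge"))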
+ if not include_backrefs: + def detect_is_backref(state, initiator): + impl = state.manager[key].impl + return initiator.impl is not impl - key - name of the attribute to be validated; - will be passed as the second argument to - the validation method (the first is the object instance itself). + if include_removes: + def append(state, value, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value, False) + else: + return value - validator - an function or instance method which accepts - three arguments; an instance (usually just 'self' for a method), - the key name of the attribute, and the value. The function should - return the same value given, unless it wishes to modify it. + def set_(state, value, oldvalue, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value, False) + else: + return value - """ - self.key = key - self.validator = validator + def remove(state, value, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + validator(state.obj(), key, value, True) - def append(self, state, value, initiator): - return self.validator(state.obj(), self.key, value) + else: + def append(state, value, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value) + else: + return value - def set(self, state, value, oldvalue, initiator): - return self.validator(state.obj(), self.key, value) + def set_(state, value, oldvalue, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value) + else: + return value -def polymorphic_union(table_map, typecolname, aliasname='p_union'): + event.listen(desc, 'append', append, raw=True, retval=True) + event.listen(desc, 'set', set_, raw=True, retval=True) + if include_removes: + event.listen(desc, "remove", remove, raw=True, retval=True) + + +def polymorphic_union(table_map, typecolname, + aliasname='p_union', cast_nulls=True): """Create a ``UNION`` statement used by a polymorphic mapper. See :ref:`concrete_inheritance` for an example of how this is used. + + :param table_map: mapping of polymorphic identities to + :class:`.Table` objects. + :param typecolname: string name of a "discriminator" column, which will be + derived from the query, producing the polymorphic identity for + each row. If ``None``, no polymorphic discriminator is generated. + :param aliasname: name of the :func:`~sqlalchemy.sql.expression.alias()` + construct generated. + :param cast_nulls: if True, non-existent columns, which are represented + as labeled NULLs, will be passed into CAST. This is a legacy behavior + that is problematic on some backends such as Oracle - in which case it + can be set to False. 
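_validator_events() above is the plumbing behind the public @validates decorator: include_removes adds a trailing boolean "is a removal" flag to the callback signature, and include_backrefs=False suppresses firing when the change arrives via a backref. Typical user-facing form:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import validates

    Base = declarative_base()

    class Account(Base):
        __tablename__ = 'accounts'
        id = Column(Integer, primary_key=True)
        email = Column(String)

        @validates('email')
        def check_email(self, key, address):
            # runs through the 'set' listener installed above
            assert '@' in address
            return address

    Account(email='user@example.com')   # passes validation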
+
     """
-    colnames = set()
+    colnames = util.OrderedSet()
     colnamemaps = {}
     types = {}
-    for key in table_map.keys():
+    for key in table_map:
         table = table_map[key]
 
-        # mysql doesnt like selecting from a select; make it an alias of the select
+        # mysql doesn't like selecting from a select;
+        # make it an alias of the select
         if isinstance(table, sql.Select):
             table = table.alias()
             table_map[key] = table
@@ -111,45 +169,81 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'):
         try:
             return colnamemaps[table][name]
         except KeyError:
-            return sql.cast(sql.null(), types[name]).label(name)
+            if cast_nulls:
+                return sql.cast(sql.null(), types[name]).label(name)
+            else:
+                return sql.type_coerce(sql.null(), types[name]).label(name)
 
     result = []
-    for type, table in table_map.iteritems():
+    for type, table in table_map.items():
         if typecolname is not None:
-            result.append(sql.select([col(name, table) for name in colnames] +
-                                     [sql.literal_column("'%s'" % type).label(typecolname)],
-                                     from_obj=[table]))
+            result.append(
+                sql.select([col(name, table) for name in colnames] +
+                           [sql.literal_column(
+                               sql_util._quote_ddl_expr(type)).
+                               label(typecolname)],
+                           from_obj=[table]))
         else:
             result.append(sql.select([col(name, table) for name in colnames],
                                      from_obj=[table]))
     return sql.union_all(*result).alias(aliasname)
 
-def identity_key(*args, **kwargs):
-    """Get an identity key.
 
-    Valid call signatures:
+def identity_key(*args, **kwargs):
+    """Generate "identity key" tuples, as are used as keys in the
+    :attr:`.Session.identity_map` dictionary.
+
+    This function has several call styles:
 
     * ``identity_key(class, ident)``
 
-      class
-          mapped class (must be a positional argument)
+      This form receives a mapped class and a primary key scalar or
+      tuple as an argument.
 
-      ident
-          primary key, if the key is composite this is a tuple
+      E.g.::
+
+        >>> identity_key(MyClass, (1, 2))
+        (<class '__main__.MyClass'>, (1, 2))
+
+      :param class: mapped class (must be a positional argument)
+      :param ident: primary key, may be a scalar or tuple argument.
 
     * ``identity_key(instance=instance)``
 
-      instance
-          object instance (must be given as a keyword arg)
+      This form will produce the identity key for a given instance.  The
+      instance need not be persistent, only that its primary key attributes
+      are populated (else the key will contain ``None`` for those missing
+      values).
+
+      E.g.::
+
+        >>> instance = MyClass(1, 2)
+        >>> identity_key(instance=instance)
+        (<class '__main__.MyClass'>, (1, 2))
+
+      In this form, the given instance is ultimately run through
+      :meth:`.Mapper.identity_key_from_instance`, which will have the
+      effect of performing a database check for the corresponding row
+      if the object is expired.
+
+      :param instance: object instance (must be given as a keyword arg)
 
     * ``identity_key(class, row=row)``
 
-      class
-          mapped class (must be a positional argument)
+      This form is similar to the class/tuple form, except it is passed a
+      database result row as a :class:`.RowProxy` object.
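For a concrete feel of the polymorphic_union() changes above, here is what a call produces for two hypothetical concrete-inheritance tables; columns absent from one table come back as labeled NULLs, CAST by default or type_coerce under the new cast_nulls=False:

    from sqlalchemy import Table, Column, Integer, String, MetaData
    from sqlalchemy.orm import polymorphic_union

    metadata = MetaData()
    engineers = Table('engineers', metadata,
                      Column('id', Integer, primary_key=True),
                      Column('name', String(50)),
                      Column('engineer_info', String(50)))
    managers = Table('managers', metadata,
                     Column('id', Integer, primary_key=True),
                     Column('name', String(50)),
                     Column('manager_data', String(50)))

    pjoin = polymorphic_union(
        {'engineer': engineers, 'manager': managers}, 'type', 'pjoin')
    print(pjoin.select())   # UNION ALL of two NULL-padded SELECTs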
-      row
-          result proxy row (must be given as a keyword arg)
+      E.g.::
+
+        >>> row = engine.execute("select * from table where a=1 and b=2").\
+first()
+        >>> identity_key(MyClass, row=row)
+        (<class '__main__.MyClass'>, (1, 2))
+
+    :param class: mapped class (must be a positional argument)
+    :param row: :class:`.RowProxy` row returned by a :class:`.ResultProxy`
+     (must be given as a keyword arg)
 
     """
     if args:
@@ -164,505 +258,801 @@ def identity_key(*args, **kwargs):
         elif len(args) == 3:
             class_, ident = args
         else:
-            raise sa_exc.ArgumentError("expected up to three "
-                "positional arguments, got %s" % len(args))
+            raise sa_exc.ArgumentError(
+                "expected up to three positional arguments, "
+                "got %s" % len(args))
 
         if kwargs:
             raise sa_exc.ArgumentError("unknown keyword arguments: %s"
-                                       % ", ".join(kwargs.keys()))
+                                       % ", ".join(kwargs))
         mapper = class_mapper(class_)
 
         if "ident" in locals():
-            return mapper.identity_key_from_primary_key(ident)
+            return mapper.identity_key_from_primary_key(util.to_list(ident))
 
         return mapper.identity_key_from_row(row)
 
     instance = kwargs.pop("instance")
 
     if kwargs:
         raise sa_exc.ArgumentError("unknown keyword arguments: %s"
-                                   % ", ".join(kwargs.keys()))
+                                   % ", ".join(kwargs))
 
     mapper = object_mapper(instance)
     return mapper.identity_key_from_instance(instance)
 
-class ExtensionCarrier(dict):
-    """Fronts an ordered collection of MapperExtension objects.
-
-    Bundles multiple MapperExtensions into a unified callable unit,
-    encapsulating ordering, looping and EXT_CONTINUE logic.  The
-    ExtensionCarrier implements the MapperExtension interface, e.g.::
-
-        carrier.after_insert(...args...)
-
-    The dictionary interface provides containment for implemented
-    method names mapped to a callable which executes that method
-    for participating extensions.
-
-    """
-
-    interface = set(method for method in dir(MapperExtension)
-                    if not method.startswith('_'))
-
-    def __init__(self, extensions=None):
-        self._extensions = []
-        for ext in extensions or ():
-            self.append(ext)
-
-    def copy(self):
-        return ExtensionCarrier(self._extensions)
-
-    def push(self, extension):
-        """Insert a MapperExtension at the beginning of the collection."""
-        self._register(extension)
-        self._extensions.insert(0, extension)
-
-    def append(self, extension):
-        """Append a MapperExtension at the end of the collection."""
-        self._register(extension)
-        self._extensions.append(extension)
-
-    def __iter__(self):
-        """Iterate over MapperExtensions in the collection."""
-        return iter(self._extensions)
-
-    def _register(self, extension):
-        """Register callable fronts for overridden interface methods."""
-
-        for method in self.interface.difference(self):
-            impl = getattr(extension, method, None)
-            if impl and impl is not getattr(MapperExtension, method):
-                self[method] = self._create_do(method)
-
-    def _create_do(self, method):
-        """Return a closure that loops over impls of the named method."""
-
-        def _do(*args, **kwargs):
-            for ext in self._extensions:
-                ret = getattr(ext, method)(*args, **kwargs)
-                if ret is not EXT_CONTINUE:
-                    return ret
-            else:
-                return EXT_CONTINUE
-        _do.__name__ = method
-        return _do
-
-    @staticmethod
-    def _pass(*args, **kwargs):
-        return EXT_CONTINUE
-
-    def __getattr__(self, key):
-        """Delegate MapperExtension methods to bundled fronts."""
-
-        if key not in self.interface:
-            raise AttributeError(key)
-        return self.get(key, self._pass)
 
 class ORMAdapter(sql_util.ColumnAdapter):
-    """Extends ColumnAdapter to accept ORM entities.
-
-    The selectable is extracted from the given entity,
-    and the AliasedClass if any is referenced.
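The ORMAdapter rewrite in this hunk switches from the old _entity_info() helper to the inspection system; the same inspect() entry point is the public route to an alias's mapper and selectable. A small sketch with a hypothetical Node class:

    from sqlalchemy import Column, Integer, inspect
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import aliased

    Base = declarative_base()

    class Node(Base):
        __tablename__ = 'nodes'
        id = Column(Integer, primary_key=True)

    n1 = aliased(Node, name='n1')
    insp = inspect(n1)            # an AliasedInsp, via the hooks below
    print(insp.mapper)            # the Node Mapper
    print(insp.selectable)        # anonymous alias of the 'nodes' table
    print(insp.is_aliased_class)  # True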
+ """ColumnAdapter subclass which excludes adaptation of entities from + non-matching mappers. """ - def __init__(self, entity, equivalents=None, chain_to=None, adapt_required=False): - self.mapper, selectable, is_aliased_class = _entity_info(entity) + + def __init__(self, entity, equivalents=None, adapt_required=False, + chain_to=None, allow_label_resolve=True, + anonymize_labels=False): + info = inspection.inspect(entity) + + self.mapper = info.mapper + selectable = info.selectable + is_aliased_class = info.is_aliased_class if is_aliased_class: self.aliased_class = entity else: self.aliased_class = None - sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to, adapt_required=adapt_required) - def replace(self, elem): + sql_util.ColumnAdapter.__init__( + self, selectable, equivalents, chain_to, + adapt_required=adapt_required, + allow_label_resolve=allow_label_resolve, + anonymize_labels=anonymize_labels, + include_fn=self._include_fn + ) + + def _include_fn(self, elem): entity = elem._annotations.get('parentmapper', None) - if not entity or entity.isa(self.mapper): - return sql_util.ColumnAdapter.replace(self, elem) - else: - return None + return not entity or entity.isa(self.mapper) + class AliasedClass(object): - """Represents an "aliased" form of a mapped class for usage with Query. + r"""Represents an "aliased" form of a mapped class for usage with Query. The ORM equivalent of a :func:`sqlalchemy.sql.expression.alias` construct, this object mimics the mapped class using a __getattr__ scheme and maintains a reference to a - real :class:`~sqlalchemy.sql.expression.Alias` object. - - Usage is via the :class:`~sqlalchemy.orm.aliased()` synonym:: + real :class:`~sqlalchemy.sql.expression.Alias` object. + + Usage is via the :func:`.orm.aliased` function, or alternatively + via the :func:`.orm.with_polymorphic` function. + + Usage example:: # find all pairs of users with the same name user_alias = aliased(User) - session.query(User, user_alias).\\ - join((user_alias, User.id > user_alias.id)).\\ + session.query(User, user_alias).\ + join((user_alias, User.id > user_alias.id)).\ filter(User.name==user_alias.name) + The resulting object is an instance of :class:`.AliasedClass`. + This object implements an attribute scheme which produces the + same attribute and method interface as the original mapped + class, allowing :class:`.AliasedClass` to be compatible + with any attribute technique which works on the original class, + including hybrid attributes (see :ref:`hybrids_toplevel`). + + The :class:`.AliasedClass` can be inspected for its underlying + :class:`.Mapper`, aliased selectable, and other information + using :func:`.inspect`:: + + from sqlalchemy import inspect + my_alias = aliased(MyClass) + insp = inspect(my_alias) + + The resulting inspection object is an instance of :class:`.AliasedInsp`. + + See :func:`.aliased` and :func:`.with_polymorphic` for construction + argument descriptions. + """ - def __init__(self, cls, alias=None, name=None): - self.__mapper = _class_to_mapper(cls) - self.__target = self.__mapper.class_ + + def __init__(self, cls, alias=None, + name=None, + flat=False, + adapt_on_names=False, + # TODO: None for default here? 
+                 with_polymorphic_mappers=(),
+                 with_polymorphic_discriminator=None,
+                 base_alias=None,
+                 use_mapper_path=False):
+        mapper = _class_to_mapper(cls)
         if alias is None:
-            alias = self.__mapper._with_polymorphic_selectable.alias()
-        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
-        self.__alias = alias
-        # used to assign a name to the RowTuple object
-        # returned by Query.
-        self._sa_label_name = name
-        self.__name__ = 'AliasedClass_' + str(self.__target)
+            alias = mapper._with_polymorphic_selectable.alias(
+                name=name, flat=flat)
 
-    def __getstate__(self):
-        return {'mapper':self.__mapper, 'alias':self.__alias, 'name':self._sa_label_name}
+        self._aliased_insp = AliasedInsp(
+            self,
+            mapper,
+            alias,
+            name,
+            with_polymorphic_mappers
+            if with_polymorphic_mappers
+            else mapper.with_polymorphic_mappers,
+            with_polymorphic_discriminator
+            if with_polymorphic_discriminator is not None
+            else mapper.polymorphic_on,
+            base_alias,
+            use_mapper_path,
+            adapt_on_names
+        )
 
-    def __setstate__(self, state):
-        self.__mapper = state['mapper']
-        self.__target = self.__mapper.class_
-        alias = state['alias']
-        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
-        self.__alias = alias
-        name = state['name']
-        self._sa_label_name = name
-        self.__name__ = 'AliasedClass_' + str(self.__target)
-
-    def __adapt_element(self, elem):
-        return self.__adapter.traverse(elem)._annotate({'parententity': self, 'parentmapper':self.__mapper})
-
-    def __adapt_prop(self, prop):
-        existing = getattr(self.__target, prop.key)
-        comparator = existing.comparator.adapted(self.__adapt_element)
-
-        queryattr = attributes.QueryableAttribute(prop.key,
-            impl=existing.impl, parententity=self, comparator=comparator)
-        setattr(self, prop.key, queryattr)
-        return queryattr
+        self.__name__ = 'AliasedClass_%s' % mapper.class_.__name__
 
     def __getattr__(self, key):
-        prop = self.__mapper._get_property(key, raiseerr=False)
-        if prop:
-            return self.__adapt_prop(prop)
-
-        for base in self.__target.__mro__:
-            try:
-                attr = object.__getattribute__(base, key)
-            except AttributeError:
-                continue
-            else:
-                break
+        try:
+            _aliased_insp = self.__dict__['_aliased_insp']
+        except KeyError:
+            raise AttributeError()
         else:
-            raise AttributeError(key)
+            for base in _aliased_insp._target.__mro__:
+                try:
+                    attr = object.__getattribute__(base, key)
+                except AttributeError:
+                    continue
+                else:
+                    break
+            else:
+                raise AttributeError(key)
 
-        if hasattr(attr, 'func_code'):
-            is_method = getattr(self.__target, key, None)
-            if is_method and is_method.im_self is not None:
-                return util.types.MethodType(attr.im_func, self, self)
+        if isinstance(attr, PropComparator):
+            ret = attr.adapt_to_entity(_aliased_insp)
+            setattr(self, key, ret)
+            return ret
+        elif hasattr(attr, 'func_code'):
+            is_method = getattr(_aliased_insp._target, key, None)
+            if is_method and is_method.__self__ is not None:
+                return util.types.MethodType(attr.__func__, self, self)
             else:
                 return None
         elif hasattr(attr, '__get__'):
-            return attr.__get__(None, self)
+            ret = attr.__get__(None, self)
+            if isinstance(ret, PropComparator):
+                return ret.adapt_to_entity(_aliased_insp)
+            else:
+                return ret
         else:
             return attr
 
     def __repr__(self):
         return '<AliasedClass at 0x%x; %s>' % (
-            id(self), self.__target.__name__)
+            id(self), self._aliased_insp._target.__name__)
+
+
+class AliasedInsp(InspectionAttr):
+    """Provide an inspection interface for an
+    :class:`.AliasedClass` object.
+ + The :class:`.AliasedInsp` object is returned + given an :class:`.AliasedClass` using the + :func:`.inspect` function:: + + from sqlalchemy import inspect + from sqlalchemy.orm import aliased + + my_alias = aliased(MyMappedClass) + insp = inspect(my_alias) + + Attributes on :class:`.AliasedInsp` + include: + + * ``entity`` - the :class:`.AliasedClass` represented. + * ``mapper`` - the :class:`.Mapper` mapping the underlying class. + * ``selectable`` - the :class:`.Alias` construct which ultimately + represents an aliased :class:`.Table` or :class:`.Select` + construct. + * ``name`` - the name of the alias. Also is used as the attribute + name when returned in a result tuple from :class:`.Query`. + * ``with_polymorphic_mappers`` - collection of :class:`.Mapper` objects + indicating all those mappers expressed in the select construct + for the :class:`.AliasedClass`. + * ``polymorphic_on`` - an alternate column or SQL expression which + will be used as the "discriminator" for a polymorphic load. + + .. seealso:: + + :ref:`inspection_toplevel` + + """ + + def __init__(self, entity, mapper, selectable, name, + with_polymorphic_mappers, polymorphic_on, + _base_alias, _use_mapper_path, adapt_on_names): + self.entity = entity + self.mapper = mapper + self.selectable = selectable + self.name = name + self.with_polymorphic_mappers = with_polymorphic_mappers + self.polymorphic_on = polymorphic_on + self._base_alias = _base_alias or self + self._use_mapper_path = _use_mapper_path + + self._adapter = sql_util.ColumnAdapter( + selectable, equivalents=mapper._equivalent_columns, + adapt_on_names=adapt_on_names, anonymize_labels=True) + + self._adapt_on_names = adapt_on_names + self._target = mapper.class_ + + for poly in self.with_polymorphic_mappers: + if poly is not mapper: + setattr(self.entity, poly.class_.__name__, + AliasedClass(poly.class_, selectable, base_alias=self, + adapt_on_names=adapt_on_names, + use_mapper_path=_use_mapper_path)) + + is_aliased_class = True + "always returns True" + + @property + def class_(self): + """Return the mapped class ultimately represented by this + :class:`.AliasedInsp`.""" + return self.mapper.class_ + + @util.memoized_property + def _path_registry(self): + if self._use_mapper_path: + return self.mapper._path_registry + else: + return PathRegistry.per_mapper(self) + + def __getstate__(self): + return { + 'entity': self.entity, + 'mapper': self.mapper, + 'alias': self.selectable, + 'name': self.name, + 'adapt_on_names': self._adapt_on_names, + 'with_polymorphic_mappers': + self.with_polymorphic_mappers, + 'with_polymorphic_discriminator': + self.polymorphic_on, + 'base_alias': self._base_alias, + 'use_mapper_path': self._use_mapper_path + } + + def __setstate__(self, state): + self.__init__( + state['entity'], + state['mapper'], + state['alias'], + state['name'], + state['with_polymorphic_mappers'], + state['with_polymorphic_discriminator'], + state['base_alias'], + state['use_mapper_path'], + state['adapt_on_names'] + ) + + def _adapt_element(self, elem): + return self._adapter.traverse(elem).\ + _annotate({ + 'parententity': self, + 'parentmapper': self.mapper} + ) + + def _entity_for_mapper(self, mapper): + self_poly = self.with_polymorphic_mappers + if mapper in self_poly: + if mapper is self.mapper: + return self + else: + return getattr( + self.entity, mapper.class_.__name__)._aliased_insp + elif mapper.isa(self.mapper): + return self + else: + assert False, "mapper %s doesn't correspond to %s" % ( + mapper, self) + + @util.memoized_property + def 
_memoized_values(self):
+        return {}
+
+    def _memo(self, key, callable_, *args, **kw):
+        if key in self._memoized_values:
+            return self._memoized_values[key]
+        else:
+            self._memoized_values[key] = value = callable_(*args, **kw)
+            return value
+
+    def __repr__(self):
+        if self.with_polymorphic_mappers:
+            with_poly = "(%s)" % ", ".join(
+                mp.class_.__name__ for mp in self.with_polymorphic_mappers)
+        else:
+            with_poly = ""
+        return '<AliasedInsp at 0x%x; %s%s>' % (
+            id(self), self.class_.__name__, with_poly)
+
+
+inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
+inspection._inspects(AliasedInsp)(lambda target: target)
+
+
+def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False):
+    """Produce an alias of the given element, usually an :class:`.AliasedClass`
+    instance.
+
+    E.g.::
+
+        my_alias = aliased(MyClass)
+
+        session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id)
+
+    The :func:`.aliased` function is used to create an ad-hoc mapping
+    of a mapped class to a new selectable.  By default, a selectable
+    is generated from the normally mapped selectable (typically a
+    :class:`.Table`) using the :meth:`.FromClause.alias` method.
+    However, :func:`.aliased` can also be used to link the class to
+    a new :func:`.select` statement.  Also, the :func:`.with_polymorphic`
+    function is a variant of :func:`.aliased` that is intended to specify
+    a so-called "polymorphic selectable", that corresponds to the union
+    of several joined-inheritance subclasses at once.
+
+    For convenience, the :func:`.aliased` function also accepts plain
+    :class:`.FromClause` constructs, such as a :class:`.Table` or
+    :func:`.select` construct.  In those cases, the :meth:`.FromClause.alias`
+    method is called on the object and the new :class:`.Alias` object
+    returned.  The returned :class:`.Alias` is not ORM-mapped in this case.
+
+    :param element: element to be aliased.  Is normally a mapped class,
+     but for convenience can also be a :class:`.FromClause` element.
+
+    :param alias: Optional selectable unit to map the element to.  This should
+     normally be a :class:`.Alias` object corresponding to the :class:`.Table`
+     to which the class is mapped, or to a :func:`.select` construct that
+     is compatible with the mapping.  By default, a simple anonymous
+     alias of the mapped table is generated.
+
+    :param name: optional string name to use for the alias, if not specified
+     by the ``alias`` parameter.  The name, among other things, forms the
+     attribute name that will be accessible via tuples returned by a
+     :class:`.Query` object.
+
+    :param flat: Boolean, will be passed through to the
+     :meth:`.FromClause.alias` call so that aliases of :class:`.Join` objects
+     don't include an enclosing SELECT.  This can lead to more efficient
+     queries in many circumstances.  A JOIN against a nested JOIN will be
+     rewritten as a JOIN against an aliased SELECT subquery on backends that
+     don't support this syntax.
+
+     .. versionadded:: 0.9.0
+
+     .. seealso:: :meth:`.Join.alias`
+
+    :param adapt_on_names: if True, more liberal "matching" will be used when
+     mapping the mapped columns of the ORM entity to those of the
+     given selectable - a name-based match will be performed if the
+     given selectable doesn't otherwise have a column that corresponds
+     to one on the entity.  The use case for this is when associating
+     an entity with some derived selectable such as one that uses
+     aggregate functions::
+
+        class UnitPrice(Base):
+            __tablename__ = 'unit_price'
+            ...
+            unit_id = Column(Integer)
+            price = Column(Numeric)
+
+        aggregated_unit_price = Session.query(
+            func.sum(UnitPrice.price).label('price')
+        ).group_by(UnitPrice.unit_id).subquery()
+
+        aggregated_unit_price = aliased(UnitPrice,
+            alias=aggregated_unit_price, adapt_on_names=True)
+
+     Above, functions on ``aggregated_unit_price`` which refer to
+     ``.price`` will return the
+     ``func.sum(UnitPrice.price).label('price')`` column, as it is
+     matched on the name "price".  Ordinarily, the "price" function
+     wouldn't have any "column correspondence" to the actual
+     ``UnitPrice.price`` column as it is not a proxy of the original.
+
+     .. versionadded:: 0.7.3
+
+
+    """
+    if isinstance(element, expression.FromClause):
+        if adapt_on_names:
+            raise sa_exc.ArgumentError(
+                "adapt_on_names only applies to ORM elements"
+            )
+        return element.alias(name, flat=flat)
+    else:
+        return AliasedClass(element, alias=alias, flat=flat,
+                            name=name, adapt_on_names=adapt_on_names)
+
+
+def with_polymorphic(base, classes, selectable=False,
+                     flat=False,
+                     polymorphic_on=None, aliased=False,
+                     innerjoin=False, _use_mapper_path=False,
+                     _existing_alias=None):
+    """Produce an :class:`.AliasedClass` construct which specifies
+    columns for descendant mappers of the given base.
+
+    .. versionadded:: 0.8
+        :func:`.orm.with_polymorphic` is in addition to the existing
+        :class:`.Query` method :meth:`.Query.with_polymorphic`,
+        which has the same purpose but is not as flexible in its usage.
+
+    Using this method will ensure that each descendant mapper's
+    tables are included in the FROM clause, and will allow filter()
+    criterion to be used against those tables.  The resulting
+    instances will also have those columns already loaded so that
+    no "post fetch" of those columns will be required.
+
+    See the examples at :ref:`with_polymorphic`.
+
+    :param base: Base class to be aliased.
+
+    :param classes: a single class or mapper, or list of
+     class/mappers, which inherit from the base class.
+     Alternatively, it may also be the string ``'*'``, in which case
+     all descending mapped classes will be added to the FROM clause.
+
+    :param aliased: when True, the selectable will be wrapped in an
+     alias, that is ``(SELECT * FROM <fromclauses>) AS anon_1``.
+     This can be important when using the with_polymorphic()
+     to create the target of a JOIN on a backend that does not
+     support parenthesized joins, such as SQLite and older
+     versions of MySQL.
+
+    :param flat: Boolean, will be passed through to the
+     :meth:`.FromClause.alias` call so that aliases of :class:`.Join`
+     objects don't include an enclosing SELECT.  This can lead to more
+     efficient queries in many circumstances.  A JOIN against a nested JOIN
+     will be rewritten as a JOIN against an aliased SELECT subquery on
+     backends that don't support this syntax.
+
+     Setting ``flat`` to ``True`` implies the ``aliased`` flag is
+     also ``True``.
+
+     .. versionadded:: 0.9.0
+
+     .. seealso:: :meth:`.Join.alias`
+
+    :param selectable: a table or select() statement that will
+     be used in place of the generated FROM clause.  This argument is
+     required if any of the desired classes use concrete table
+     inheritance, since SQLAlchemy currently cannot generate UNIONs
+     among tables automatically.  If used, the ``selectable`` argument
+     must represent the full set of tables and columns mapped by every
+     mapped class.  Otherwise, the unaccounted mapped columns will
+     result in their table being appended directly to the FROM clause
+     which will usually lead to incorrect results.
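
# A runnable sketch of the adapt_on_names example from the docstring above,
# assuming the hypothetical UnitPrice mapping shown there.
from sqlalchemy import Column, Integer, Numeric, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, aliased

Base = declarative_base()

class UnitPrice(Base):
    __tablename__ = 'unit_price'
    id = Column(Integer, primary_key=True)
    unit_id = Column(Integer)
    price = Column(Numeric)

session = Session()
subq = session.query(
    UnitPrice.unit_id,
    func.sum(UnitPrice.price).label('price')
).group_by(UnitPrice.unit_id).subquery()

# name-based matching links UnitPrice.price to the labeled SUM column
agg_price = aliased(UnitPrice, alias=subq, adapt_on_names=True)
print(agg_price.price > 10)   # renders against the subquery's "price" label
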
+ + :param polymorphic_on: a column to be used as the "discriminator" + column for the given selectable. If not given, the polymorphic_on + attribute of the base classes' mapper will be used, if any. This + is useful for mappings that don't have polymorphic loading + behavior by default. + + :param innerjoin: if True, an INNER JOIN will be used. This should + only be specified if querying for one specific subtype only + """ + primary_mapper = _class_to_mapper(base) + if _existing_alias: + assert _existing_alias.mapper is primary_mapper + classes = util.to_set(classes) + new_classes = set([ + mp.class_ for mp in + _existing_alias.with_polymorphic_mappers]) + if classes == new_classes: + return _existing_alias + else: + classes = classes.union(new_classes) + mappers, selectable = primary_mapper.\ + _with_polymorphic_args(classes, selectable, + innerjoin=innerjoin) + if aliased or flat: + selectable = selectable.alias(flat=flat) + return AliasedClass(base, + selectable, + with_polymorphic_mappers=mappers, + with_polymorphic_discriminator=polymorphic_on, + use_mapper_path=_use_mapper_path) + def _orm_annotate(element, exclude=None): - """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag. + """Deep copy the given ClauseElement, annotating each element with the + "_orm_adapt" flag. Elements within the exclude collection will be cloned but not annotated. """ - return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude) + return sql_util._deep_annotate(element, {'_orm_adapt': True}, exclude) + + +def _orm_deannotate(element): + """Remove annotations that link a column to a particular mapping. + + Note this doesn't affect "remote" and "foreign" annotations + passed by the :func:`.orm.foreign` and :func:`.orm.remote` + annotators. 
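
# A sketch of with_polymorphic() under an assumed joined-table inheritance
# hierarchy; Employee and Engineer are hypothetical classes for illustration.
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, with_polymorphic

Base = declarative_base()

class Employee(Base):
    __tablename__ = 'employee'
    id = Column(Integer, primary_key=True)
    type = Column(String(20))
    __mapper_args__ = {'polymorphic_on': type,
                       'polymorphic_identity': 'employee'}

class Engineer(Employee):
    __tablename__ = 'engineer'
    id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
    engineer_info = Column(String(50))
    __mapper_args__ = {'polymorphic_identity': 'engineer'}

# one SELECT spanning employee LEFT OUTER JOIN engineer; subclass columns
# are reachable through attributes of the returned AliasedClass
emp_poly = with_polymorphic(Employee, [Engineer])
query = Session().query(emp_poly).filter(
    emp_poly.Engineer.engineer_info == 'vibration')
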
+ + """ + + return sql_util._deep_deannotate(element, + values=("_orm_adapt", "parententity") + ) + + +def _orm_full_deannotate(element): + return sql_util._deep_deannotate(element) -_orm_deannotate = sql_util._deep_deannotate class _ORMJoin(expression.Join): """Extend Join to support ORM constructs as input.""" __visit_name__ = expression.Join.__visit_name__ - def __init__(self, left, right, onclause=None, isouter=False, join_to_left=True): - adapt_from = None + def __init__( + self, + left, right, onclause=None, isouter=False, + full=False, _left_memo=None, _right_memo=None): - if hasattr(left, '_orm_mappers'): - left_mapper = left._orm_mappers[1] - if join_to_left: - adapt_from = left.right + left_info = inspection.inspect(left) + left_orm_info = getattr(left, '_joined_from_info', left_info) + + right_info = inspection.inspect(right) + adapt_to = right_info.selectable + + self._joined_from_info = right_info + + self._left_memo = _left_memo + self._right_memo = _right_memo + + if isinstance(onclause, util.string_types): + onclause = getattr(left_orm_info.entity, onclause) + + if isinstance(onclause, attributes.QueryableAttribute): + on_selectable = onclause.comparator._source_selectable() + prop = onclause.property + elif isinstance(onclause, MapperProperty): + prop = onclause + on_selectable = prop.parent.selectable else: - left_mapper, left, left_is_aliased = _entity_info(left) - if join_to_left and (left_is_aliased or not left_mapper): - adapt_from = left + prop = None - right_mapper, right, right_is_aliased = _entity_info(right) - if right_is_aliased: - adapt_to = right - else: - adapt_to = None - - if left_mapper or right_mapper: - self._orm_mappers = (left_mapper, right_mapper) - - if isinstance(onclause, basestring): - prop = left_mapper.get_property(onclause) - elif isinstance(onclause, attributes.QueryableAttribute): - if adapt_from is None: - adapt_from = onclause.__clause_element__() - prop = onclause.property - elif isinstance(onclause, MapperProperty): - prop = onclause + if prop: + if sql_util.clause_is_present( + on_selectable, left_info.selectable): + adapt_from = on_selectable else: - prop = None + adapt_from = left_info.selectable - if prop: - pj, sj, source, dest, secondary, target_adapter = prop._create_joins( - source_selectable=adapt_from, - dest_selectable=adapt_to, - source_polymorphic=True, - dest_polymorphic=True, - of_type=right_mapper) + pj, sj, source, dest, \ + secondary, target_adapter = prop._create_joins( + source_selectable=adapt_from, + dest_selectable=adapt_to, + source_polymorphic=True, + dest_polymorphic=True, + of_type=right_info.mapper) - if sj is not None: + if sj is not None: + if isouter: + # note this is an inner join from secondary->right + right = sql.join(secondary, right, sj) + onclause = pj + else: left = sql.join(left, secondary, pj, isouter) onclause = sj - else: - onclause = pj - self._target_adapter = target_adapter + else: + onclause = pj + self._target_adapter = target_adapter - expression.Join.__init__(self, left, right, onclause, isouter) + expression.Join.__init__(self, left, right, onclause, isouter, full) - def join(self, right, onclause=None, isouter=False, join_to_left=True): - return _ORMJoin(self, right, onclause, isouter, join_to_left) + if not prop and getattr(right_info, 'mapper', None) \ + and right_info.mapper.single: + # if single inheritance target and we are using a manual + # or implicit ON clause, augment it the same way we'd augment the + # WHERE. 
+            single_crit = right_info.mapper._single_table_criterion
+            if single_crit is not None:
+                if right_info.is_aliased_class:
+                    single_crit = right_info._adapter.traverse(single_crit)
+                self.onclause = self.onclause & single_crit

-    def outerjoin(self, right, onclause=None, join_to_left=True):
-        return _ORMJoin(self, right, onclause, True, join_to_left)
+    def _splice_into_center(self, other):
+        """Splice a join into the center.

-def join(left, right, onclause=None, isouter=False, join_to_left=True):
-    """Produce an inner join between left and right clauses.
+        Given join(a, b) and join(b, c), return join(a, b).join(c)

-    In addition to the interface provided by
-    :func:`~sqlalchemy.sql.expression.join()`, left and right may be mapped
-    classes or AliasedClass instances. The onclause may be a
-    string name of a relationship(), or a class-bound descriptor
-    representing a relationship.
+        """
+        leftmost = other
+        while isinstance(leftmost, sql.Join):
+            leftmost = leftmost.left

-    join_to_left indicates to attempt aliasing the ON clause,
-    in whatever form it is passed, to the selectable
-    passed as the left side. If False, the onclause
-    is used as is.
+        assert self.right is leftmost
+
+        left = _ORMJoin(
+            self.left, other.left,
+            self.onclause, isouter=self.isouter,
+            _left_memo=self._left_memo,
+            _right_memo=other._left_memo
+        )
+
+        return _ORMJoin(
+            left,
+            other.right,
+            other.onclause, isouter=other.isouter,
+            _right_memo=other._right_memo
+        )
+
+    def join(
+            self, right, onclause=None,
+            isouter=False, full=False, join_to_left=None):
+        return _ORMJoin(self, right, onclause, isouter, full)
+
+    def outerjoin(
+            self, right, onclause=None,
+            full=False, join_to_left=None):
+        return _ORMJoin(self, right, onclause, True, full=full)
+
+
+def join(
+        left, right, onclause=None, isouter=False,
+        full=False, join_to_left=None):
+    r"""Produce an inner join between left and right clauses.
+
+    :func:`.orm.join` is an extension to the core join interface
+    provided by :func:`.sql.expression.join()`, where the
+    left and right selectables may be not only core selectable
+    objects such as :class:`.Table`, but also mapped classes or
+    :class:`.AliasedClass` instances.  The "on" clause can
+    be a SQL expression, or an attribute or string name
+    referencing a configured :func:`.relationship`.
+
+    :func:`.orm.join` is not commonly needed in modern usage,
+    as its functionality is encapsulated within that of the
+    :meth:`.Query.join` method, which features a
+    significant amount of automation beyond :func:`.orm.join`
+    by itself.  Explicit usage of :func:`.orm.join`
+    with :class:`.Query` involves usage of the
+    :meth:`.Query.select_from` method, as in::
+
+        from sqlalchemy.orm import join
+        session.query(User).\
+            select_from(join(User, Address, User.addresses)).\
+            filter(Address.email_address=='foo@bar.com')
+
+    In modern SQLAlchemy the above join can be written more
+    succinctly as::
+
+        session.query(User).\
+            join(User.addresses).\
+            filter(Address.email_address=='foo@bar.com')
+
+    See :meth:`.Query.join` for information on modern usage
+    of ORM level joins.
+
+    .. versionchanged:: 0.8.1 - the ``join_to_left`` parameter
+       is no longer used, and is deprecated.
+
+    """
+    return _ORMJoin(left, right, onclause, isouter, full)

-def outerjoin(left, right, onclause=None, join_to_left=True):
+
+def outerjoin(left, right, onclause=None, full=False, join_to_left=None):
     """Produce a left outer join between left and right clauses.
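
# The join() docstring example above made self-contained; User and Address
# are hypothetical mapped classes joined by a ``User.addresses`` relationship.
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, join, relationship

Base = declarative_base()

class User(Base):
    __tablename__ = 'user'
    id = Column(Integer, primary_key=True)
    addresses = relationship('Address')

class Address(Base):
    __tablename__ = 'address'
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey('user.id'))
    email_address = Column(String)

session = Session()
q = session.query(User).\
    select_from(join(User, Address, User.addresses)).\
    filter(Address.email_address == 'foo@bar.com')
print(q)   # the ON clause is derived from the configured relationship
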
- In addition to the interface provided by - :func:`~sqlalchemy.sql.expression.outerjoin()`, left and right may be mapped - classes or AliasedClass instances. The onclause may be a - string name of a relationship(), or a class-bound descriptor - representing a relationship. + This is the "outer join" version of the :func:`.orm.join` function, + featuring the same behavior except that an OUTER JOIN is generated. + See that function's documentation for other usage details. """ - return _ORMJoin(left, right, onclause, True, join_to_left) + return _ORMJoin(left, right, onclause, True, full) + def with_parent(instance, prop): - """Return criterion which selects instances with a given parent. + """Create filtering criterion that relates this query's primary entity + to the given related instance, using established :func:`.relationship()` + configuration. - instance - a parent instance, which should be persistent or detached. + The SQL rendered is the same as that rendered when a lazy loader + would fire off from the given parent on that attribute, meaning + that the appropriate state is taken from the parent object in + Python without the need to render joins to the parent table + in the rendered statement. - property - a class-attached descriptor, MapperProperty or string property name - attached to the parent instance. + .. versionchanged:: 0.6.4 + This method accepts parent instances in all + persistence states, including transient, persistent, and detached. + Only the requisite primary key/foreign key attributes need to + be populated. Previous versions didn't work with transient + instances. - \**kwargs - all extra keyword arguments are propagated to the constructor of - Query. + :param instance: + An instance which has some :func:`.relationship`. + + :param property: + String property name, or class-bound attribute, which indicates + what relationship from the instance should be used to reconcile the + parent/child relationship. """ - if isinstance(prop, basestring): + if isinstance(prop, util.string_types): mapper = object_mapper(instance) - prop = mapper.get_property(prop, resolve_synonyms=True) + prop = getattr(mapper.class_, prop).property elif isinstance(prop, attributes.QueryableAttribute): prop = prop.property - return prop.compare(operators.eq, instance, value_is_parent=True) + return prop._with_parent(instance) -def _entity_info(entity, compile=True): - """Return mapping information given a class, mapper, or AliasedClass. - - Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this - is an aliased() construct. - - If the given entity is not a mapper, mapped class, or aliased construct, - returns None, the entity, False. This is typically used to allow - unmapped selectables through. - - """ - if isinstance(entity, AliasedClass): - return entity._AliasedClass__mapper, entity._AliasedClass__alias, True - - global mapperlib - if mapperlib is None: - from sqlalchemy.orm import mapperlib - - if isinstance(entity, mapperlib.Mapper): - mapper = entity - - elif isinstance(entity, type): - class_manager = attributes.manager_of_class(entity) - - if class_manager is None: - return None, entity, False - - mapper = class_manager.mapper - else: - return None, entity, False - - if compile: - mapper = mapper.compile() - return mapper, mapper._with_polymorphic_selectable, False - -def _entity_descriptor(entity, key): - """Return attribute/property information given an entity and string name. - - Returns a 2-tuple representing InstrumentedAttribute/MapperProperty. 
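
# with_parent() sketch, continuing the hypothetical User/Address mapping
# from the previous example; only primary key attributes need be populated.
from sqlalchemy.orm import with_parent

some_user = User(id=5)   # transient parents are accepted as of 0.6.4

# criterion is rendered against the child table alone; no JOIN to "user"
q = session.query(Address).filter(with_parent(some_user, User.addresses))
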
- - """ - if isinstance(entity, AliasedClass): - try: - desc = getattr(entity, key) - return desc, desc.property - except AttributeError: - raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key)) - - elif isinstance(entity, type): - try: - desc = attributes.manager_of_class(entity)[key] - return desc, desc.property - except KeyError: - raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key)) - - else: - try: - desc = entity.class_manager[key] - return desc, desc.property - except KeyError: - raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key)) - -def _orm_columns(entity): - mapper, selectable, is_aliased_class = _entity_info(entity) - if isinstance(selectable, expression.Selectable): - return [c for c in selectable.c] - else: - return [selectable] - -def _orm_selectable(entity): - mapper, selectable, is_aliased_class = _entity_info(entity) - return selectable - -def _is_aliased_class(entity): - return isinstance(entity, AliasedClass) - -def _state_mapper(state): - return state.manager.mapper - -def object_mapper(instance): - """Given an object, return the primary Mapper associated with the object instance. - - Raises UnmappedInstanceError if no mapping is configured. - - """ - try: - state = attributes.instance_state(instance) - if not state.manager.mapper: - raise exc.UnmappedInstanceError(instance) - return state.manager.mapper - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) - -def class_mapper(class_, compile=True): - """Given a class, return the primary Mapper associated with the key. - - Raises UnmappedClassError if no mapping is configured. - - """ - try: - class_manager = attributes.manager_of_class(class_) - mapper = class_manager.mapper - - # HACK until [ticket:1142] is complete - if mapper is None: - raise AttributeError - - except exc.NO_STATE: - raise exc.UnmappedClassError(class_) - - if compile: - mapper = mapper.compile() - return mapper - -def _class_to_mapper(class_or_mapper, compile=True): - if _is_aliased_class(class_or_mapper): - return class_or_mapper._AliasedClass__mapper - elif isinstance(class_or_mapper, type): - return class_mapper(class_or_mapper, compile=compile) - elif hasattr(class_or_mapper, 'compile'): - if compile: - return class_or_mapper.compile() - else: - return class_or_mapper - else: - raise exc.UnmappedClassError(class_or_mapper) - def has_identity(object): + """Return True if the given object has a database + identity. + + This typically corresponds to the object being + in either the persistent or detached state. + + .. 
seealso:: + + :func:`.was_deleted` + + """ state = attributes.instance_state(object) - return _state_has_identity(state) - -def _state_has_identity(state): - return bool(state.key) - -def _is_mapped_class(cls): - global mapperlib - if mapperlib is None: - from sqlalchemy.orm import mapperlib - if isinstance(cls, (AliasedClass, mapperlib.Mapper)): - return True - if isinstance(cls, expression.ClauseElement): - return False - if isinstance(cls, type): - manager = attributes.manager_of_class(cls) - return manager and _INSTRUMENTOR in manager.info - return False - -def instance_str(instance): - """Return a string describing an instance.""" - - return state_str(attributes.instance_state(instance)) - -def state_str(state): - """Return a string describing an instance via its InstanceState.""" - - if state is None: - return "None" - else: - return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj())) - -def attribute_str(instance, attribute): - return instance_str(instance) + "." + attribute - -def state_attribute_str(state, attribute): - return state_str(state) + "." + attribute - -def identity_equal(a, b): - if a is b: - return True - if a is None or b is None: - return False - try: - state_a = attributes.instance_state(a) - state_b = attributes.instance_state(b) - except exc.NO_STATE: - return False - if state_a.key is None or state_b.key is None: - return False - return state_a.key == state_b.key + return state.has_identity -# TODO: Avoid circular import. -attributes.identity_equal = identity_equal -attributes._is_aliased_class = _is_aliased_class -attributes._entity_info = _entity_info +def was_deleted(object): + """Return True if the given object was deleted + within a session flush. + + This is regardless of whether or not the object is + persistent or detached. + + .. versionadded:: 0.8.0 + + .. seealso:: + + :attr:`.InstanceState.was_deleted` + + """ + + state = attributes.instance_state(object) + return state.was_deleted + + +def randomize_unitofwork(): + """Use random-ordering sets within the unit of work in order + to detect unit of work sorting issues. + + This is a utility function that can be used to help reproduce + inconsistent unit of work sorting issues. For example, + if two kinds of objects A and B are being inserted, and + B has a foreign key reference to A - the A must be inserted first. + However, if there is no relationship between A and B, the unit of work + won't know to perform this sorting, and an operation may or may not + fail, depending on how the ordering works out. Since Python sets + and dictionaries have non-deterministic ordering, such an issue may + occur on some runs and not on others, and in practice it tends to + have a great dependence on the state of the interpreter. This leads + to so-called "heisenbugs" where changing entirely irrelevant aspects + of the test program still cause the failure behavior to change. + + By calling ``randomize_unitofwork()`` when a script first runs, the + ordering of a key series of sets within the unit of work implementation + are randomized, so that the script can be minimized down to the + fundamental mapping and operation that's failing, while still reproducing + the issue on at least some runs. + + This utility is also available when running the test suite via the + ``--reversetop`` flag. + + .. versionadded:: 0.8.1 created a standalone version of the + ``--reversetop`` feature. 
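
# Sketch of the state predicates documented above, using a throwaway
# in-memory database and a hypothetical Widget mapping.
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm.util import has_identity, was_deleted

Base = declarative_base()

class Widget(Base):
    __tablename__ = 'widget'
    id = Column(Integer, primary_key=True)

engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
session = Session(engine)

w = Widget()
print(has_identity(w))    # False: transient, no database identity yet

session.add(w)
session.commit()
print(has_identity(w))    # True: persistent

session.delete(w)
session.flush()
print(was_deleted(w))     # True: deleted within a flush
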
+ + """ + from sqlalchemy.orm import unitofwork, session, mapper, dependency + from sqlalchemy.util import topological + from sqlalchemy.testing.util import RandomSet + topological.set = unitofwork.set = session.set = mapper.set = \ + dependency.set = RandomSet diff --git a/sqlalchemy/pool.py b/sqlalchemy/pool.py index 31ab7fa..b58fdaa 100644 --- a/sqlalchemy/pool.py +++ b/sqlalchemy/pool.py @@ -1,5 +1,6 @@ -# pool.py - Connection pooling for SQLAlchemy -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# sqlalchemy/pool.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -16,16 +17,21 @@ regular DB-API connect() methods to be transparently managed by a SQLAlchemy connection pool. """ -import weakref, time, threading +import time +import traceback +import weakref -from sqlalchemy import exc, log -from sqlalchemy import queue as sqla_queue -from sqlalchemy.util import threading, pickle, as_interface, memoized_property +from . import exc, log, event, interfaces, util +from .util import queue as sqla_queue +from .util import threading, memoized_property, \ + chop_traceback +from collections import deque proxies = {} + def manage(module, **params): - """Return a proxy for a DB-API module that automatically + r"""Return a proxy for a DB-API module that automatically pools connections. Given a DB-API 2.0 module and pool management parameters, returns @@ -36,9 +42,9 @@ def manage(module, **params): :param module: a DB-API 2.0 database module :param poolclass: the class used by the pool module to provide - pooling. Defaults to :class:`QueuePool`. + pooling. Defaults to :class:`.QueuePool`. - :param \*\*params: will be passed through to *poolclass* + :param \**params: will be passed through to *poolclass* """ try: @@ -46,24 +52,58 @@ def manage(module, **params): except KeyError: return proxies.setdefault(module, _DBProxy(module, **params)) + def clear_managers(): """Remove all current DB-API 2.0 managers. All pools and connections are disposed. """ - for manager in proxies.itervalues(): + for manager in proxies.values(): manager.close() proxies.clear() +reset_rollback = util.symbol('reset_rollback') +reset_commit = util.symbol('reset_commit') +reset_none = util.symbol('reset_none') + + +class _ConnDialect(object): + + """partial implementation of :class:`.Dialect` + which provides DBAPI connection methods. + + When a :class:`.Pool` is combined with an :class:`.Engine`, + the :class:`.Engine` replaces this with its own + :class:`.Dialect`. + + """ + + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection): + dbapi_connection.commit() + + def do_close(self, dbapi_connection): + dbapi_connection.close() + + class Pool(log.Identified): + """Abstract base class for connection pools.""" - def __init__(self, - creator, recycle=-1, echo=None, - use_threadlocal=False, - logging_name=None, - reset_on_return=True, listeners=None): + _dialect = _ConnDialect() + + def __init__(self, + creator, recycle=-1, echo=None, + use_threadlocal=False, + logging_name=None, + reset_on_return=True, + listeners=None, + events=None, + dialect=None, + _dispatch=None): """ Construct a Pool. @@ -77,8 +117,8 @@ class Pool(log.Identified): replaced with a newly opened connection. Defaults to -1. 
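
# manage() sketch: transparent pooling over a plain DB-API module; the
# keyword parameters here pass through to the default QueuePool and are
# illustrative only.
import sqlite3
from sqlalchemy import pool

managed_sqlite = pool.manage(sqlite3, pool_size=5, max_overflow=10)

conn = managed_sqlite.connect(':memory:')   # checked out from a QueuePool
conn.close()                                # returned to the pool, not closed
pool.clear_managers()                       # dispose all managed pools
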
:param logging_name: String identifier which will be used within - the "name" field of logging records generated within the - "sqlalchemy.pool" logger. Defaults to a hexstring of the object's + the "name" field of logging records generated within the + "sqlalchemy.pool" logger. Defaults to a hexstring of the object's id. :param echo: If True, connections being pulled and retrieved @@ -93,48 +133,232 @@ class Pool(log.Identified): already been retrieved from the pool and has not been returned yet. Offers a slight performance advantage at the cost of individual transactions by default. The - :meth:`unique_connection` method is provided to bypass the - threadlocal behavior installed into :meth:`connect`. + :meth:`.Pool.unique_connection` method is provided to return + a consistently unique connection to bypass this behavior + when the flag is set. - :param reset_on_return: If true, reset the database state of - connections returned to the pool. This is typically a - ROLLBACK to release locks and transaction resources. - Disable at your own peril. Defaults to True. + .. warning:: The :paramref:`.Pool.use_threadlocal` flag + **does not affect the behavior** of :meth:`.Engine.connect`. + :meth:`.Engine.connect` makes use of the + :meth:`.Pool.unique_connection` method which **does not use thread + local context**. To produce a :class:`.Connection` which refers + to the :meth:`.Pool.connect` method, use + :meth:`.Engine.contextual_connect`. - :param listeners: A list of + Note that other SQLAlchemy connectivity systems such as + :meth:`.Engine.execute` as well as the orm + :class:`.Session` make use of + :meth:`.Engine.contextual_connect` internally, so these functions + are compatible with the :paramref:`.Pool.use_threadlocal` setting. + + .. seealso:: + + :ref:`threadlocal_strategy` - contains detail on the + "threadlocal" engine strategy, which provides a more comprehensive + approach to "threadlocal" connectivity for the specific + use case of using :class:`.Engine` and :class:`.Connection` objects + directly. + + :param reset_on_return: Determine steps to take on + connections as they are returned to the pool. + reset_on_return can have any of these values: + + * ``"rollback"`` - call rollback() on the connection, + to release locks and transaction resources. + This is the default value. The vast majority + of use cases should leave this value set. + * ``True`` - same as 'rollback', this is here for + backwards compatibility. + * ``"commit"`` - call commit() on the connection, + to release locks and transaction resources. + A commit here may be desirable for databases that + cache query plans if a commit is emitted, + such as Microsoft SQL Server. However, this + value is more dangerous than 'rollback' because + any data changes present on the transaction + are committed unconditionally. + * ``None`` - don't do anything on the connection. + This setting should only be made on a database + that has no transaction support at all, + namely MySQL MyISAM. By not doing anything, + performance can be improved. This + setting should **never be selected** for a + database that supports transactions, + as it will lead to deadlocks and stale + state. + * ``"none"`` - same as ``None`` + + .. versionadded:: 0.9.10 + + * ``False`` - same as None, this is here for + backwards compatibility. + + .. versionchanged:: 0.7.6 + :paramref:`.Pool.reset_on_return` accepts ``"rollback"`` + and ``"commit"`` arguments. 
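
# reset_on_return sketch using QueuePool directly; ``getconn`` is a
# hypothetical creator callable.
import sqlite3
from sqlalchemy.pool import QueuePool

def getconn():
    return sqlite3.connect(':memory:')

p1 = QueuePool(getconn)                             # 'rollback' is the default
p2 = QueuePool(getconn, reset_on_return='commit')   # COMMIT on checkin
p3 = QueuePool(getconn, reset_on_return=None)       # no reset at all

conn = p2.connect()
conn.close()   # emits COMMIT via the dialect's do_commit()
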
+ + :param events: a list of 2-tuples, each of the form + ``(callable, target)`` which will be passed to :func:`.event.listen` + upon construction. Provided here so that event listeners + can be assigned via :func:`.create_engine` before dialect-level + listeners are applied. + + :param listeners: Deprecated. A list of :class:`~sqlalchemy.interfaces.PoolListener`-like objects or dictionaries of callables that receive events when DB-API connections are created, checked out and checked in to the - pool. + pool. This has been superseded by + :func:`~sqlalchemy.event.listen`. + + :param dialect: a :class:`.Dialect` that will handle the job + of calling rollback(), close(), or commit() on DBAPI connections. + If omitted, a built-in "stub" dialect is used. Applications that + make use of :func:`~.create_engine` should not use this parameter + as it is handled by the engine creation strategy. + + .. versionadded:: 1.1 - ``dialect`` is now a public parameter + to the :class:`.Pool`. """ if logging_name: - self.logging_name = logging_name - self.logger = log.instance_logger(self, echoflag=echo) + self.logging_name = self._orig_logging_name = logging_name + else: + self._orig_logging_name = None + + log.instance_logger(self, echoflag=echo) self._threadconns = threading.local() self._creator = creator self._recycle = recycle + self._invalidate_time = 0 self._use_threadlocal = use_threadlocal - self._reset_on_return = reset_on_return - self.echo = echo - self.listeners = [] - self._on_connect = [] - self._on_first_connect = [] - self._on_checkout = [] - self._on_checkin = [] + if reset_on_return in ('rollback', True, reset_rollback): + self._reset_on_return = reset_rollback + elif reset_on_return in ('none', None, False, reset_none): + self._reset_on_return = reset_none + elif reset_on_return in ('commit', reset_commit): + self._reset_on_return = reset_commit + else: + raise exc.ArgumentError( + "Invalid value for 'reset_on_return': %r" + % reset_on_return) + self.echo = echo + + if _dispatch: + self.dispatch._update(_dispatch, only_propagate=False) + if dialect: + self._dialect = dialect + if events: + for fn, target in events: + event.listen(self, target, fn) if listeners: + util.warn_deprecated( + "The 'listeners' argument to Pool (and " + "create_engine()) is deprecated. Use event.listen().") for l in listeners: self.add_listener(l) - def unique_connection(self): - return _ConnectionFairy(self).checkout() + @property + def _creator(self): + return self.__dict__['_creator'] + + @_creator.setter + def _creator(self, creator): + self.__dict__['_creator'] = creator + self._invoke_creator = self._should_wrap_creator(creator) + + def _should_wrap_creator(self, creator): + """Detect if creator accepts a single argument, or is sent + as a legacy style no-arg function. 
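
# Sketch of the two creator signatures that _should_wrap_creator detects;
# both callables are hypothetical.
import sqlite3
from sqlalchemy.pool import QueuePool

def legacy_creator():
    # no-arg style; wrapped internally as ``lambda crec: creator()``
    return sqlite3.connect(':memory:')

def record_creator(connection_record=None):
    # matches the (['connection_record'], (None,)) argspec and receives
    # the _ConnectionRecord directly
    return sqlite3.connect(':memory:')

QueuePool(legacy_creator).connect().close()
QueuePool(record_creator).connect().close()
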
+ + """ + + try: + argspec = util.get_callable_argspec(self._creator, no_self=True) + except TypeError: + return lambda crec: creator() + + defaulted = argspec[3] is not None and len(argspec[3]) or 0 + positionals = len(argspec[0]) - defaulted + + # look for the exact arg signature that DefaultStrategy + # sends us + if (argspec[0], argspec[3]) == (['connection_record'], (None,)): + return creator + # or just a single positional + elif positionals == 1: + return creator + # all other cases, just wrap and assume legacy "creator" callable + # thing + else: + return lambda crec: creator() + + def _close_connection(self, connection): + self.logger.debug("Closing connection %r", connection) + + try: + self._dialect.do_close(connection) + except Exception: + self.logger.error("Exception closing connection %r", + connection, exc_info=True) + + @util.deprecated( + 2.7, "Pool.add_listener is deprecated. Use event.listen()") + def add_listener(self, listener): + """Add a :class:`.PoolListener`-like object to this pool. + + ``listener`` may be an object that implements some or all of + PoolListener, or a dictionary of callables containing implementations + of some or all of the named methods in PoolListener. + + """ + interfaces.PoolListener._adapt_listener(self, listener) + + def unique_connection(self): + """Produce a DBAPI connection that is not referenced by any + thread-local context. + + This method is equivalent to :meth:`.Pool.connect` when the + :paramref:`.Pool.use_threadlocal` flag is not set to True. + When :paramref:`.Pool.use_threadlocal` is True, the + :meth:`.Pool.unique_connection` method provides a means of bypassing + the threadlocal context. + + """ + return _ConnectionFairy._checkout(self) + + def _create_connection(self): + """Called by subclasses to create a new ConnectionRecord.""" - def create_connection(self): return _ConnectionRecord(self) + def _invalidate(self, connection, exception=None): + """Mark all connections established within the generation + of the given connection as invalidated. + + If this pool's last invalidate time is before when the given + connection was created, update the timestamp til now. Otherwise, + no action is performed. + + Connections with a start time prior to this pool's invalidation + time will be recycled upon next checkout. + """ + + rec = getattr(connection, "_connection_record", None) + if not rec or self._invalidate_time < rec.starttime: + self._invalidate_time = time.time() + if getattr(connection, 'is_valid', False): + connection.invalidate(exception) + def recreate(self): - """Return a new instance with identical creation arguments.""" + """Return a new :class:`.Pool`, of the same class as this one + and configured with identical creation arguments. + + This method is used in conjunction with :meth:`dispose` + to close out an entire :class:`.Pool` and create a new one in + its place. + + """ raise NotImplementedError() @@ -142,189 +366,477 @@ class Pool(log.Identified): """Dispose of this pool. This method leaves the possibility of checked-out connections - remaining open, It is advised to not reuse the pool once dispose() - is called, and to instead use a new pool constructed by the - recreate() method. + remaining open, as it only affects connections that are + idle in the pool. + + See also the :meth:`Pool.recreate` method. + """ raise NotImplementedError() def connect(self): + """Return a DBAPI connection from the pool. 
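
# Checkout/checkin sketch: connect() hands out an instrumented connection
# whose close() returns it to the pool rather than closing it.
import sqlite3
from sqlalchemy.pool import QueuePool

p = QueuePool(lambda: sqlite3.connect(':memory:'), pool_size=2)

conn = p.connect()
print(p.checkedout())   # 1
conn.close()            # checkin; the DBAPI connection stays open in the pool
print(p.checkedout())   # 0
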
+ + The connection is instrumented such that when its + ``close()`` method is called, the connection will be returned to + the pool. + + """ if not self._use_threadlocal: - return _ConnectionFairy(self).checkout() + return _ConnectionFairy._checkout(self) try: rec = self._threadconns.current() - if rec: - return rec.checkout() except AttributeError: pass + else: + if rec is not None: + return rec._checkout_existing() - agent = _ConnectionFairy(self) - self._threadconns.current = weakref.ref(agent) - return agent.checkout() + return _ConnectionFairy._checkout(self, self._threadconns) - def return_conn(self, record): - if self._use_threadlocal and hasattr(self._threadconns, "current"): - del self._threadconns.current - self.do_return_conn(record) + def _return_conn(self, record): + """Given a _ConnectionRecord, return it to the :class:`.Pool`. - def get(self): - return self.do_get() + This method is called when an instrumented DBAPI connection + has its ``close()`` method called. + + """ + if self._use_threadlocal: + try: + del self._threadconns.current + except AttributeError: + pass + self._do_return_conn(record) + + def _do_get(self): + """Implementation for :meth:`get`, supplied by subclasses.""" - def do_get(self): raise NotImplementedError() - def do_return_conn(self, conn): + def _do_return_conn(self, conn): + """Implementation for :meth:`return_conn`, supplied by subclasses.""" + raise NotImplementedError() def status(self): raise NotImplementedError() - def add_listener(self, listener): - """Add a ``PoolListener``-like object to this pool. - - ``listener`` may be an object that implements some or all of - PoolListener, or a dictionary of callables containing implementations - of some or all of the named methods in PoolListener. - - """ - - listener = as_interface(listener, - methods=('connect', 'first_connect', 'checkout', 'checkin')) - - self.listeners.append(listener) - if hasattr(listener, 'connect'): - self._on_connect.append(listener) - if hasattr(listener, 'first_connect'): - self._on_first_connect.append(listener) - if hasattr(listener, 'checkout'): - self._on_checkout.append(listener) - if hasattr(listener, 'checkin'): - self._on_checkin.append(listener) class _ConnectionRecord(object): - def __init__(self, pool): + + """Internal object which maintains an individual DBAPI connection + referenced by a :class:`.Pool`. + + The :class:`._ConnectionRecord` object always exists for any particular + DBAPI connection whether or not that DBAPI connection has been + "checked out". This is in contrast to the :class:`._ConnectionFairy` + which is only a public facade to the DBAPI connection while it is checked + out. + + A :class:`._ConnectionRecord` may exist for a span longer than that + of a single DBAPI connection. For example, if the + :meth:`._ConnectionRecord.invalidate` + method is called, the DBAPI connection associated with this + :class:`._ConnectionRecord` + will be discarded, but the :class:`._ConnectionRecord` may be used again, + in which case a new DBAPI connection is produced when the :class:`.Pool` + next uses this record. + + The :class:`._ConnectionRecord` is delivered along with connection + pool events, including :meth:`.PoolEvents.connect` and + :meth:`.PoolEvents.checkout`, however :class:`._ConnectionRecord` still + remains an internal object whose API and internals may change. + + .. 
seealso::

+        :class:`._ConnectionFairy`
+
+    """
+
+    def __init__(self, pool, connect=True):
         self.__pool = pool
-        self.connection = self.__connect()
-        self.info = {}
-        ls = pool.__dict__.pop('_on_first_connect', None)
-        if ls is not None:
-            for l in ls:
-                l.first_connect(self.connection, self)
-        if pool._on_connect:
-            for l in pool._on_connect:
-                l.connect(self.connection, self)
+        if connect:
+            self.__connect(first_connect_check=True)
+        self.finalize_callback = deque()
+
+    fairy_ref = None
+
+    starttime = None
+
+    connection = None
+    """A reference to the actual DBAPI connection being tracked.
+
+    May be ``None`` if this :class:`._ConnectionRecord` has been marked
+    as invalidated; a new DBAPI connection may replace it if the owning
+    pool calls upon this :class:`._ConnectionRecord` to reconnect.
+
+    """
+
+    _soft_invalidate_time = 0
+
+    @util.memoized_property
+    def info(self):
+        """The ``.info`` dictionary associated with the DBAPI connection.
+
+        This dictionary is shared among the :attr:`._ConnectionFairy.info`
+        and :attr:`.Connection.info` accessors.
+
+        .. note::
+
+            The lifespan of this dictionary is linked to the
+            DBAPI connection itself, meaning that it is **discarded** each time
+            the DBAPI connection is closed and/or invalidated.   The
+            :attr:`._ConnectionRecord.record_info` dictionary remains
+            persistent throughout the lifespan of the
+            :class:`._ConnectionRecord` container.
+
+        """
+        return {}
+
+    @util.memoized_property
+    def record_info(self):
+        """An "info" dictionary associated with the connection record
+        itself.
+
+        Unlike the :attr:`._ConnectionRecord.info` dictionary, which is linked
+        to the lifespan of the DBAPI connection, this dictionary is linked
+        to the lifespan of the :class:`._ConnectionRecord` container itself
+        and will remain persistent throughout the life of the
+        :class:`._ConnectionRecord`.
+
+        .. versionadded:: 1.1
+
+        """
+        return {}
+
+    @classmethod
+    def checkout(cls, pool):
+        rec = pool._do_get()
+        try:
+            dbapi_connection = rec.get_connection()
+        except:
+            with util.safe_reraise():
+                rec.checkin()
+        echo = pool._should_log_debug()
+        fairy = _ConnectionFairy(dbapi_connection, rec, echo)
+        rec.fairy_ref = weakref.ref(
+            fairy,
+            lambda ref: _finalize_fairy and
+            _finalize_fairy(
+                dbapi_connection,
+                rec, pool, ref, echo)
+        )
+        _refs.add(rec)
+        if echo:
+            pool.logger.debug("Connection %r checked out from pool",
+                              dbapi_connection)
+        return fairy
+
+    def checkin(self):
+        self.fairy_ref = None
+        connection = self.connection
+        pool = self.__pool
+        while self.finalize_callback:
+            finalizer = self.finalize_callback.pop()
+            finalizer(connection)
+        if pool.dispatch.checkin:
+            pool.dispatch.checkin(connection, self)
+        pool._return_conn(self)
+
+    @property
+    def in_use(self):
+        return self.fairy_ref is not None
+
+    @property
+    def last_connect_time(self):
+        return self.starttime

     def close(self):
         if self.connection is not None:
-            self.__pool.logger.debug("Closing connection %r", self.connection)
-            try:
-                self.connection.close()
-            except (SystemExit, KeyboardInterrupt):
-                raise
-            except:
-                self.__pool.logger.debug("Exception closing connection %r",
-                                         self.connection)
+            self.__close()

-    def invalidate(self, e=None):
-        if e is not None:
-            self.__pool.logger.info("Invalidate connection %r (reason: %s:%s)",
-                self.connection, e.__class__.__name__, e)
+    def invalidate(self, e=None, soft=False):
+        """Invalidate the DBAPI connection held by this :class:`._ConnectionRecord`.
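
# Sketch contrasting the two dictionaries described above: ``.info`` follows
# the DBAPI connection, while ``.record_info`` follows the record itself.
import sqlite3
from sqlalchemy.pool import QueuePool

p = QueuePool(lambda: sqlite3.connect(':memory:'))

fairy = p.connect()
fairy.info['tag'] = 'discarded with the DBAPI connection'
fairy.record_info['tag'] = 'kept for the life of the record'
fairy.close()
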
+ + This method is called for all connection invalidations, including + when the :meth:`._ConnectionFairy.invalidate` or + :meth:`.Connection.invalidate` methods are called, as well as when any + so-called "automatic invalidation" condition occurs. + + :param e: an exception object indicating a reason for the invalidation. + + :param soft: if True, the connection isn't closed; instead, this + connection will be recycled on next checkout. + + .. versionadded:: 1.0.3 + + .. seealso:: + + :ref:`pool_connection_invalidation` + + """ + # already invalidated + if self.connection is None: + return + if soft: + self.__pool.dispatch.soft_invalidate(self.connection, self, e) else: - self.__pool.logger.info("Invalidate connection %r", self.connection) - self.__close() - self.connection = None + self.__pool.dispatch.invalidate(self.connection, self, e) + if e is not None: + self.__pool.logger.info( + "%sInvalidate connection %r (reason: %s:%s)", + "Soft " if soft else "", + self.connection, e.__class__.__name__, e) + else: + self.__pool.logger.info( + "%sInvalidate connection %r", + "Soft " if soft else "", + self.connection) + if soft: + self._soft_invalidate_time = time.time() + else: + self.__close() + self.connection = None def get_connection(self): + recycle = False if self.connection is None: - self.connection = self.__connect() self.info.clear() - if self.__pool._on_connect: - for l in self.__pool._on_connect: - l.connect(self.connection, self) + self.__connect() elif self.__pool._recycle > -1 and \ time.time() - self.starttime > self.__pool._recycle: - self.__pool.logger.info("Connection %r exceeded timeout; recycling", - self.connection) + self.__pool.logger.info( + "Connection %r exceeded timeout; recycling", + self.connection) + recycle = True + elif self.__pool._invalidate_time > self.starttime: + self.__pool.logger.info( + "Connection %r invalidated due to pool invalidation; " + + "recycling", + self.connection + ) + recycle = True + elif self._soft_invalidate_time > self.starttime: + self.__pool.logger.info( + "Connection %r invalidated due to local soft invalidation; " + + "recycling", + self.connection + ) + recycle = True + + if recycle: self.__close() - self.connection = self.__connect() self.info.clear() - if self.__pool._on_connect: - for l in self.__pool._on_connect: - l.connect(self.connection, self) + + self.__connect() return self.connection def __close(self): - try: - self.__pool.logger.debug("Closing connection %r", self.connection) - self.connection.close() - except (SystemExit, KeyboardInterrupt): - raise - except Exception, e: - self.__pool.logger.debug("Connection %r threw an error on close: %s", - self.connection, e) + self.finalize_callback.clear() + if self.__pool.dispatch.close: + self.__pool.dispatch.close(self.connection, self) + self.__pool._close_connection(self.connection) + self.connection = None - def __connect(self): + def __connect(self, first_connect_check=False): + pool = self.__pool + + # ensure any existing connection is removed, so that if + # creator fails, this attribute stays None + self.connection = None try: self.starttime = time.time() - connection = self.__pool._creator() - self.__pool.logger.debug("Created new connection %r", connection) - return connection - except Exception, e: - self.__pool.logger.debug("Error on connect(): %s", e) + connection = pool._invoke_creator(self) + pool.logger.debug("Created new connection %r", connection) + self.connection = connection + except Exception as e: + pool.logger.debug("Error on connect(): %s", e) raise + 
else: + if first_connect_check: + pool.dispatch.first_connect.\ + for_modify(pool.dispatch).\ + exec_once(self.connection, self) + if pool.dispatch.connect: + pool.dispatch.connect(self.connection, self) -def _finalize_fairy(connection, connection_record, pool, ref=None): +def _finalize_fairy(connection, connection_record, + pool, ref, echo, fairy=None): + """Cleanup for a :class:`._ConnectionFairy` whether or not it's already + been garbage collected. + + """ _refs.discard(connection_record) - - if ref is not None and (connection_record.fairy is not ref or isinstance(pool, AssertionPool)): + + if ref is not None and \ + connection_record.fairy_ref is not ref: return if connection is not None: + if connection_record and echo: + pool.logger.debug("Connection %r being returned to pool", + connection) + try: - if pool._reset_on_return: - connection.rollback() + fairy = fairy or _ConnectionFairy( + connection, connection_record, echo) + assert fairy.connection is connection + fairy._reset(pool) + # Immediately close detached instances - if connection_record is None: - connection.close() - except Exception, e: - if connection_record is not None: + if not connection_record: + if pool.dispatch.close_detached: + pool.dispatch.close_detached(connection) + pool._close_connection(connection) + except BaseException as e: + pool.logger.error( + "Exception during reset or similar", exc_info=True) + if connection_record: connection_record.invalidate(e=e) - if isinstance(e, (SystemExit, KeyboardInterrupt)): + if not isinstance(e, Exception): raise - - if connection_record is not None: - connection_record.fairy = None - pool.logger.debug("Connection %r being returned to pool", connection) - if pool._on_checkin: - for l in pool._on_checkin: - l.checkin(connection, connection_record) - pool.return_conn(connection_record) + + if connection_record: + connection_record.checkin() + _refs = set() -class _ConnectionFairy(object): - """Proxies a DB-API connection and provides return-on-dereference support.""" - __slots__ = '_pool', '__counter', 'connection', \ - '_connection_record', '__weakref__', '_detached_info' - - def __init__(self, pool): - self._pool = pool - self.__counter = 0 - try: - rec = self._connection_record = pool.get() - conn = self.connection = self._connection_record.get_connection() - rec.fairy = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref)) - _refs.add(rec) - except: - self.connection = None # helps with endless __getattr__ loops later on - self._connection_record = None - raise - self._pool.logger.debug("Connection %r checked out from pool" % - self.connection) +class _ConnectionFairy(object): + + """Proxies a DBAPI connection and provides return-on-dereference + support. + + This is an internal object used by the :class:`.Pool` implementation + to provide context management to a DBAPI connection delivered by + that :class:`.Pool`. + + The name "fairy" is inspired by the fact that the + :class:`._ConnectionFairy` object's lifespan is transitory, as it lasts + only for the length of a specific DBAPI connection being checked out from + the pool, and additionally that as a transparent proxy, it is mostly + invisible. + + .. 
seealso:: + + :class:`._ConnectionRecord` + + """ + + def __init__(self, dbapi_connection, connection_record, echo): + self.connection = dbapi_connection + self._connection_record = connection_record + self._echo = echo + + connection = None + """A reference to the actual DBAPI connection being tracked.""" + + _connection_record = None + """A reference to the :class:`._ConnectionRecord` object associated + with the DBAPI connection. + + This is currently an internal accessor which is subject to change. + + """ + + _reset_agent = None + """Refer to an object with a ``.commit()`` and ``.rollback()`` method; + if non-None, the "reset-on-return" feature will call upon this object + rather than directly against the dialect-level do_rollback() and + do_commit() methods. + + In practice, a :class:`.Connection` assigns a :class:`.Transaction` object + to this variable when one is in scope so that the :class:`.Transaction` + takes the job of committing or rolling back on return if + :meth:`.Connection.close` is called while the :class:`.Transaction` + still exists. + + This is essentially an "event handler" of sorts but is simplified as an + instance variable both for performance/simplicity as well as that there + can only be one "reset agent" at a time. + """ + + @classmethod + def _checkout(cls, pool, threadconns=None, fairy=None): + if not fairy: + fairy = _ConnectionRecord.checkout(pool) + + fairy._pool = pool + fairy._counter = 0 + + if threadconns is not None: + threadconns.current = weakref.ref(fairy) + + if fairy.connection is None: + raise exc.InvalidRequestError("This connection is closed") + fairy._counter += 1 + + if not pool.dispatch.checkout or fairy._counter != 1: + return fairy + + # Pool listeners can trigger a reconnection on checkout + attempts = 2 + while attempts > 0: + try: + pool.dispatch.checkout(fairy.connection, + fairy._connection_record, + fairy) + return fairy + except exc.DisconnectionError as e: + pool.logger.info( + "Disconnection detected on checkout: %s", e) + fairy._connection_record.invalidate(e) + try: + fairy.connection = \ + fairy._connection_record.get_connection() + except: + with util.safe_reraise(): + fairy._connection_record.checkin() + + attempts -= 1 + + pool.logger.info("Reconnection attempts exhausted on checkout") + fairy.invalidate() + raise exc.InvalidRequestError("This connection is closed") + + def _checkout_existing(self): + return _ConnectionFairy._checkout(self._pool, fairy=self) + + def _checkin(self): + _finalize_fairy(self.connection, self._connection_record, + self._pool, None, self._echo, fairy=self) + self.connection = None + self._connection_record = None + + _close = _checkin + + def _reset(self, pool): + if pool.dispatch.reset: + pool.dispatch.reset(self, self._connection_record) + if pool._reset_on_return is reset_rollback: + if self._echo: + pool.logger.debug("Connection %s rollback-on-return%s", + self.connection, + ", via agent" + if self._reset_agent else "") + if self._reset_agent: + self._reset_agent.rollback() + else: + pool._dialect.do_rollback(self) + elif pool._reset_on_return is reset_commit: + if self._echo: + pool.logger.debug("Connection %s commit-on-return%s", + self.connection, + ", via agent" + if self._reset_agent else "") + if self._reset_agent: + self._reset_agent.commit() + else: + pool._dialect.do_commit(self) @property def _logger(self): @@ -332,74 +844,90 @@ class _ConnectionFairy(object): @property def is_valid(self): + """Return True if this :class:`._ConnectionFairy` still refers + to an active DBAPI 
connection.""" + return self.connection is not None - @property + @util.memoized_property def info(self): - """An info collection unique to this DB-API connection.""" + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`.ConnectionFairy`, allowing user-defined + data to be associated with the connection. - try: - return self._connection_record.info - except AttributeError: - if self.connection is None: - raise exc.InvalidRequestError("This connection is closed") - try: - return self._detached_info - except AttributeError: - self._detached_info = value = {} - return value + The data here will follow along with the DBAPI connection including + after it is returned to the connection pool and used again + in subsequent instances of :class:`._ConnectionFairy`. It is shared + with the :attr:`._ConnectionRecord.info` and :attr:`.Connection.info` + accessors. - def invalidate(self, e=None): + The dictionary associated with a particular DBAPI connection is + discarded when the connection itself is discarded. + + """ + return self._connection_record.info + + @property + def record_info(self): + """Info dictionary associated with the :class:`._ConnectionRecord + container referred to by this :class:`.ConnectionFairy`. + + Unlike the :attr:`._ConnectionFairy.info` dictionary, the lifespan + of this dictionary is persistent across connections that are + disconnected and/or invalidated within the lifespan of a + :class:`._ConnectionRecord`. + + .. versionadded:: 1.1 + + """ + if self._connection_record: + return self._connection_record.record_info + else: + return None + + def invalidate(self, e=None, soft=False): """Mark this connection as invalidated. - The connection will be immediately closed. The containing - ConnectionRecord will create a new connection when next used. + This method can be called directly, and is also called as a result + of the :meth:`.Connection.invalidate` method. When invoked, + the DBAPI connection is immediately closed and discarded from + further use by the pool. The invalidation mechanism proceeds + via the :meth:`._ConnectionRecord.invalidate` internal method. + + :param e: an exception object indicating a reason for the invalidation. + + :param soft: if True, the connection isn't closed; instead, this + connection will be recycled on next checkout. + + .. versionadded:: 1.0.3 + + .. seealso:: + + :ref:`pool_connection_invalidation` + """ if self.connection is None: - raise exc.InvalidRequestError("This connection is closed") - if self._connection_record is not None: - self._connection_record.invalidate(e=e) - self.connection = None - self._close() + util.warn("Can't invalidate an already-closed connection.") + return + if self._connection_record: + self._connection_record.invalidate(e=e, soft=soft) + if not soft: + self.connection = None + self._checkin() def cursor(self, *args, **kwargs): - try: - c = self.connection.cursor(*args, **kwargs) - return _CursorFairy(self, c) - except Exception, e: - self.invalidate(e=e) - raise + """Return a new DBAPI cursor for the underlying connection. + + This method is a proxy for the ``connection.cursor()`` DBAPI + method. 
+ + """ + return self.connection.cursor(*args, **kwargs) def __getattr__(self, key): return getattr(self.connection, key) - def checkout(self): - if self.connection is None: - raise exc.InvalidRequestError("This connection is closed") - self.__counter += 1 - - if not self._pool._on_checkout or self.__counter != 1: - return self - - # Pool listeners can trigger a reconnection on checkout - attempts = 2 - while attempts > 0: - try: - for l in self._pool._on_checkout: - l.checkout(self.connection, self._connection_record, self) - return self - except exc.DisconnectionError, e: - self._pool.logger.info( - "Disconnection detected on checkout: %s", e) - self._connection_record.invalidate(e) - self.connection = self._connection_record.get_connection() - attempts -= 1 - - self._pool.logger.info("Reconnection attempts exhausted on checkout") - self.invalidate() - raise exc.InvalidRequestError("This connection is closed") - def detach(self): """Separate this connection from its Pool. @@ -414,75 +942,53 @@ class _ConnectionFairy(object): """ if self._connection_record is not None: - _refs.remove(self._connection_record) - self._connection_record.fairy = None - self._connection_record.connection = None - self._pool.do_return_conn(self._connection_record) - self._detached_info = \ - self._connection_record.info.copy() + rec = self._connection_record + _refs.remove(rec) + rec.fairy_ref = None + rec.connection = None + # TODO: should this be _return_conn? + self._pool._do_return_conn(self._connection_record) + self.info = self.info.copy() self._connection_record = None + if self._pool.dispatch.detach: + self._pool.dispatch.detach(self.connection, rec) + def close(self): - self.__counter -= 1 - if self.__counter == 0: - self._close() + self._counter -= 1 + if self._counter == 0: + self._checkin() - def _close(self): - _finalize_fairy(self.connection, self._connection_record, self._pool) - self.connection = None - self._connection_record = None - -class _CursorFairy(object): - __slots__ = '_parent', 'cursor', 'execute' - - def __init__(self, parent, cursor): - self._parent = parent - self.cursor = cursor - self.execute = cursor.execute - - def invalidate(self, e=None): - self._parent.invalidate(e=e) - - def __iter__(self): - return iter(self.cursor) - - def close(self): - try: - self.cursor.close() - except Exception, e: - try: - ex_text = str(e) - except TypeError: - ex_text = repr(e) - self.__parent._logger.warn("Error closing cursor: %s", ex_text) - - if isinstance(e, (SystemExit, KeyboardInterrupt)): - raise - - def __setattr__(self, key, value): - if key in self.__slots__: - object.__setattr__(self, key, value) - else: - setattr(self.cursor, key, value) - - def __getattr__(self, key): - return getattr(self.cursor, key) class SingletonThreadPool(Pool): + """A Pool that maintains one connection per thread. Maintains one connection per each thread, never moving a connection to a thread other than the one which it was created in. - This is used for SQLite, which both does not handle multithreading by - default, and also requires a singleton connection if a :memory: database - is being used. + .. warning:: the :class:`.SingletonThreadPool` will call ``.close()`` + on arbitrary connections that exist beyond the size setting of + ``pool_size``, e.g. if more unique **thread identities** + than what ``pool_size`` states are used. This cleanup is + non-deterministic and not sensitive to whether or not the connections + linked to those thread identities are currently in use. 
- Options are the same as those of :class:`Pool`, as well as: + :class:`.SingletonThreadPool` may be improved in a future release, + however in its current status it is generally used only for test + scenarios using a SQLite ``:memory:`` database and is not recommended + for production use. - :param pool_size: The number of threads in which to maintain connections + + Options are the same as those of :class:`.Pool`, as well as: + + :param pool_size: The number of threads in which to maintain connections at once. Defaults to five. - + + :class:`.SingletonThreadPool` is used by the SQLite dialect + automatically when a memory-based database is used. + See :ref:`sqlite_toplevel`. + """ def __init__(self, creator, pool_size=5, **kw): @@ -494,12 +1000,15 @@ class SingletonThreadPool(Pool): def recreate(self): self.logger.info("Pool recreating") - return SingletonThreadPool(self._creator, - pool_size=self.size, - recycle=self._recycle, - echo=self.echo, - use_threadlocal=self._use_threadlocal, - listeners=self.listeners) + return self.__class__(self._creator, + pool_size=self.size, + recycle=self._recycle, + echo=self.echo, + logging_name=self._orig_logging_name, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + dialect=self._dialect) def dispose(self): """Dispose of this pool.""" @@ -507,62 +1016,65 @@ class SingletonThreadPool(Pool): for conn in self._all_conns: try: conn.close() - except (SystemExit, KeyboardInterrupt): - raise - except: + except Exception: # pysqlite won't even let you close a conn from a thread # that didn't create it pass - - self._all_conns.clear() - - def dispose_local(self): - if hasattr(self._conn, 'current'): - conn = self._conn.current() - self._all_conns.discard(conn) - del self._conn.current - def cleanup(self): - while len(self._all_conns) > self.size: - self._all_conns.pop() + self._all_conns.clear() + + def _cleanup(self): + while len(self._all_conns) >= self.size: + c = self._all_conns.pop() + c.close() def status(self): - return "SingletonThreadPool id:%d size: %d" % (id(self), len(self._all_conns)) + return "SingletonThreadPool id:%d size: %d" % \ + (id(self), len(self._all_conns)) - def do_return_conn(self, conn): + def _do_return_conn(self, conn): pass - def do_get(self): + def _do_get(self): try: c = self._conn.current() if c: return c except AttributeError: pass - c = self.create_connection() + c = self._create_connection() self._conn.current = weakref.ref(c) + if len(self._all_conns) >= self.size: + self._cleanup() self._all_conns.add(c) - if len(self._all_conns) > self.size: - self.cleanup() return c + class QueuePool(Pool): - """A Pool that imposes a limit on the number of open connections.""" + + """A :class:`.Pool` that imposes a limit on the number of open connections. + + :class:`.QueuePool` is the default pooling implementation used for + all :class:`.Engine` objects, unless the SQLite dialect is in use. + + """ def __init__(self, creator, pool_size=5, max_overflow=10, timeout=30, **kw): - """ + r""" Construct a QueuePool. :param creator: a callable function that returns a DB-API - connection object. The function will be called with - parameters. + connection object, same as that of :paramref:`.Pool.creator`. - :param pool_size: The size of the pool to be maintained. This - is the largest number of connections that will be kept - persistently in the pool. 
Note that the pool begins with no - connections; once this number of connections is requested, - that number of connections will remain. Defaults to 5. + :param pool_size: The size of the pool to be maintained, + defaults to 5. This is the largest number of connections that + will be kept persistently in the pool. Note that the pool + begins with no connections; once this number of connections + is requested, that number of connections will remain. + ``pool_size`` can be set to 0 to indicate no size limit; to + disable pooling, use a :class:`~sqlalchemy.pool.NullPool` + instead. :param max_overflow: The maximum overflow size of the pool. When the number of checked-out connections reaches the @@ -580,36 +1092,10 @@ class QueuePool(Pool): :param timeout: The number of seconds to wait before giving up on returning a connection. Defaults to 30. - :param recycle: If set to non -1, number of seconds between - connection recycling, which means upon checkout, if this - timeout is surpassed the connection will be closed and - replaced with a newly opened connection. Defaults to -1. - - :param echo: If True, connections being pulled and retrieved - from the pool will be logged to the standard output, as well - as pool sizing information. Echoing can also be achieved by - enabling logging for the "sqlalchemy.pool" - namespace. Defaults to False. - - :param use_threadlocal: If set to True, repeated calls to - :meth:`connect` within the same application thread will be - guaranteed to return the same connection object, if one has - already been retrieved from the pool and has not been - returned yet. Offers a slight performance advantage at the - cost of individual transactions by default. The - :meth:`unique_connection` method is provided to bypass the - threadlocal behavior installed into :meth:`connect`. - - :param reset_on_return: If true, reset the database state of - connections returned to the pool. This is typically a - ROLLBACK to release locks and transaction resources. - Disable at your own peril. Defaults to True. - - :param listeners: A list of - :class:`~sqlalchemy.interfaces.PoolListener`-like objects or - dictionaries of callables that receive events when DB-API - connections are created, checked out and checked in to the - pool. + :param \**kw: Other keyword arguments including + :paramref:`.Pool.recycle`, :paramref:`.Pool.echo`, + :paramref:`.Pool.reset_on_return` and others are passed to the + :class:`.Pool` constructor. 
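# --- Illustrative sketch (editorial, not part of the patch) ----------------
# The three QueuePool knobs documented above combine as: at most
# pool_size + max_overflow connections are open at once, and a checkout
# beyond that waits up to `timeout` seconds before raising TimeoutError.
# The DSN below is a placeholder, not taken from the patch.
from sqlalchemy import create_engine

engine = create_engine(
    "postgresql://scott:tiger@localhost/test",  # hypothetical DSN
    pool_size=5,        # connections kept persistently in the pool
    max_overflow=10,    # extra "burst" connections, closed on return
    pool_timeout=30,    # seconds to wait for a free connection
)
# ----------------------------------------------------------------------------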
""" Pool.__init__(self, creator, **kw) @@ -617,57 +1103,72 @@ class QueuePool(Pool): self._overflow = 0 - pool_size self._max_overflow = max_overflow self._timeout = timeout - self._overflow_lock = self._max_overflow > -1 and threading.Lock() or None + self._overflow_lock = threading.Lock() - def recreate(self): - self.logger.info("Pool recreating") - return QueuePool(self._creator, pool_size=self._pool.maxsize, - max_overflow=self._max_overflow, timeout=self._timeout, - recycle=self._recycle, echo=self.echo, - use_threadlocal=self._use_threadlocal, listeners=self.listeners) - - def do_return_conn(self, conn): + def _do_return_conn(self, conn): try: self._pool.put(conn, False) except sqla_queue.Full: - if self._overflow_lock is None: - self._overflow -= 1 - else: - self._overflow_lock.acquire() - try: - self._overflow -= 1 - finally: - self._overflow_lock.release() + try: + conn.close() + finally: + self._dec_overflow() + + def _do_get(self): + use_overflow = self._max_overflow > -1 - def do_get(self): try: - wait = self._max_overflow > -1 and self._overflow >= self._max_overflow + wait = use_overflow and self._overflow >= self._max_overflow return self._pool.get(wait, self._timeout) except sqla_queue.Empty: - if self._max_overflow > -1 and self._overflow >= self._max_overflow: + if use_overflow and self._overflow >= self._max_overflow: if not wait: - return self.do_get() + return self._do_get() else: raise exc.TimeoutError( - "QueuePool limit of size %d overflow %d reached, " - "connection timed out, timeout %d" % - (self.size(), self.overflow(), self._timeout)) + "QueuePool limit of size %d overflow %d reached, " + "connection timed out, timeout %d" % + (self.size(), self.overflow(), self._timeout)) - if self._overflow_lock is not None: - self._overflow_lock.acquire() + if self._inc_overflow(): + try: + return self._create_connection() + except: + with util.safe_reraise(): + self._dec_overflow() + else: + return self._do_get() - if self._max_overflow > -1 and self._overflow >= self._max_overflow: - if self._overflow_lock is not None: - self._overflow_lock.release() - return self.do_get() - - try: - con = self.create_connection() + def _inc_overflow(self): + if self._max_overflow == -1: + self._overflow += 1 + return True + with self._overflow_lock: + if self._overflow < self._max_overflow: self._overflow += 1 - finally: - if self._overflow_lock is not None: - self._overflow_lock.release() - return con + return True + else: + return False + + def _dec_overflow(self): + if self._max_overflow == -1: + self._overflow -= 1 + return True + with self._overflow_lock: + self._overflow -= 1 + return True + + def recreate(self): + self.logger.info("Pool recreating") + return self.__class__(self._creator, pool_size=self._pool.maxsize, + max_overflow=self._max_overflow, + timeout=self._timeout, + recycle=self._recycle, echo=self.echo, + logging_name=self._orig_logging_name, + use_threadlocal=self._use_threadlocal, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + dialect=self._dialect) def dispose(self): while True: @@ -682,11 +1183,11 @@ class QueuePool(Pool): def status(self): return "Pool size: %d Connections in pool: %d "\ - "Current Overflow: %d Current Checked out "\ - "connections: %d" % (self.size(), - self.checkedin(), - self.overflow(), - self.checkedout()) + "Current Overflow: %d Current Checked out "\ + "connections: %d" % (self.size(), + self.checkedin(), + self.overflow(), + self.checkedout()) def size(self): return self._pool.maxsize @@ -700,7 +1201,9 @@ class 
QueuePool(Pool):

     def checkedout(self):
         return self._pool.maxsize - self._pool.qsize() + self._overflow

+
 class NullPool(Pool):
+
     """A Pool which does not pool connections.

     Instead it literally opens and closes the underlying DB-API connection
@@ -710,34 +1213,39 @@ class NullPool(Pool):
     invalidation are not supported by this Pool implementation, since
     no connections are held persistently.

+    .. versionchanged:: 0.7
+        :class:`.NullPool` is used by the SQLite dialect automatically
+        when a file-based database is used. See :ref:`sqlite_toplevel`.
+
     """

     def status(self):
         return "NullPool"

-    def do_return_conn(self, conn):
+    def _do_return_conn(self, conn):
         conn.close()

-    def do_return_invalid(self, conn):
-        pass
-
-    def do_get(self):
-        return self.create_connection()
+    def _do_get(self):
+        return self._create_connection()

     def recreate(self):
         self.logger.info("Pool recreating")
-        return NullPool(self._creator,
-                        recycle=self._recycle,
-                        echo=self.echo,
-                        use_threadlocal=self._use_threadlocal,
-                        listeners=self.listeners)
+        return self.__class__(self._creator,
+                              recycle=self._recycle,
+                              echo=self.echo,
+                              logging_name=self._orig_logging_name,
+                              use_threadlocal=self._use_threadlocal,
+                              reset_on_return=self._reset_on_return,
+                              _dispatch=self.dispatch,
+                              dialect=self._dialect)

     def dispose(self):
         pass


 class StaticPool(Pool):
+
     """A Pool of exactly one connection, used for all requests.

     Reconnect-related functions such as ``recycle`` and connection
@@ -754,7 +1262,7 @@ class StaticPool(Pool):
     @memoized_property
     def connection(self):
         return _ConnectionRecord(self)
-
+
     def status(self):
         return "StaticPool"

@@ -770,47 +1278,52 @@ class StaticPool(Pool):
                           use_threadlocal=self._use_threadlocal,
                           reset_on_return=self._reset_on_return,
                           echo=self.echo,
-                          listeners=self.listeners)
+                          logging_name=self._orig_logging_name,
+                          _dispatch=self.dispatch,
+                          dialect=self._dialect)

-    def create_connection(self):
+    def _create_connection(self):
         return self._conn

-    def do_return_conn(self, conn):
+    def _do_return_conn(self, conn):
         pass

-    def do_return_invalid(self, conn):
-        pass
-
-    def do_get(self):
+    def _do_get(self):
         return self.connection

+
 class AssertionPool(Pool):
-    """A Pool that allows at most one checked out connection at any given time.
+
+    """A :class:`.Pool` that allows at most one checked out connection at
+    any given time.

     This will raise an exception if more than one connection is checked out
     at a time.  Useful for debugging code that is using more connections
     than desired.

+    .. versionchanged:: 0.7
+        :class:`.AssertionPool` also logs a traceback of where
+        the original connection was checked out, and reports
+        this in the assertion error raised.
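# --- Illustrative sketch (editorial, not part of the patch) ----------------
# NullPool, diffed above, opens and closes a real DBAPI connection on
# every checkout/checkin, so nothing is ever held between uses.  A sketch
# of selecting it explicitly; the database path is a placeholder.
from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

engine = create_engine("sqlite:///some.db", poolclass=NullPool)
with engine.connect() as conn:      # DBAPI connection opened here
    conn.execute("select 1")
# ...and fully closed again on return, rather than pooled
# ----------------------------------------------------------------------------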
+ """ def __init__(self, *args, **kw): self._conn = None self._checked_out = False + self._store_traceback = kw.pop('store_traceback', True) + self._checkout_traceback = None Pool.__init__(self, *args, **kw) - + def status(self): return "AssertionPool" - def do_return_conn(self, conn): + def _do_return_conn(self, conn): if not self._checked_out: raise AssertionError("connection is not checked out") self._checked_out = False assert conn is self._conn - def do_return_invalid(self, conn): - self._conn = None - self._checked_out = False - def dispose(self): self._checked_out = False if self._conn: @@ -818,20 +1331,31 @@ class AssertionPool(Pool): def recreate(self): self.logger.info("Pool recreating") - return AssertionPool(self._creator, echo=self.echo, - listeners=self.listeners) - - def do_get(self): + return self.__class__(self._creator, echo=self.echo, + logging_name=self._orig_logging_name, + _dispatch=self.dispatch, + dialect=self._dialect) + + def _do_get(self): if self._checked_out: - raise AssertionError("connection is already checked out") - + if self._checkout_traceback: + suffix = ' at:\n%s' % ''.join( + chop_traceback(self._checkout_traceback)) + else: + suffix = '' + raise AssertionError("connection is already checked out" + suffix) + if not self._conn: - self._conn = self.create_connection() - + self._conn = self._create_connection() + self._checked_out = True + if self._store_traceback: + self._checkout_traceback = traceback.format_stack() return self._conn + class _DBProxy(object): + """Layers connection pooling behavior on top of a standard DB-API module. Proxies a DB-API 2.0 connect() call to a connection pool keyed to the @@ -849,7 +1373,7 @@ class _DBProxy(object): a Pool class, defaulting to QueuePool Other parameters are sent to the Pool object's constructor. - + """ self.module = module @@ -857,9 +1381,9 @@ class _DBProxy(object): self.poolclass = poolclass self.pools = {} self._create_pool_mutex = threading.Lock() - + def close(self): - for key in self.pools.keys(): + for key in list(self.pools): del self.pools[key] def __del__(self): @@ -876,14 +1400,16 @@ class _DBProxy(object): self._create_pool_mutex.acquire() try: if key not in self.pools: - pool = self.poolclass(lambda: self.module.connect(*args, **kw), **self.kw) + kw.pop('sa_pool_key', None) + pool = self.poolclass( + lambda: self.module.connect(*args, **kw), **self.kw) self.pools[key] = pool return pool else: return self.pools[key] finally: self._create_pool_mutex.release() - + def connect(self, *args, **kw): """Activate a connection to the database. @@ -895,7 +1421,7 @@ class _DBProxy(object): If the pool has no available connections and allows new connections to be created, a new database connection will be made. 
- + """ return self.get_pool(*args, **kw).connect() @@ -910,4 +1436,10 @@ class _DBProxy(object): pass def _serialize(self, *args, **kw): - return pickle.dumps([args, kw]) + if "sa_pool_key" in kw: + return kw['sa_pool_key'] + + return tuple( + list(args) + + [(k, kw[k]) for k in sorted(kw)] + ) diff --git a/sqlalchemy/processors.py b/sqlalchemy/processors.py index c99ca4c..17f7ecc 100644 --- a/sqlalchemy/processors.py +++ b/sqlalchemy/processors.py @@ -1,10 +1,12 @@ -# processors.py +# sqlalchemy/processors.py +# Copyright (C) 2010-2017 the SQLAlchemy authors and contributors +# # Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""defines generic type conversion functions, as used in bind and result +"""defines generic type conversion functions, as used in bind and result processors. They all share one common characteristic: None is passed through unchanged. @@ -14,41 +16,47 @@ They all share one common characteristic: None is passed through unchanged. import codecs import re import datetime +from . import util + def str_to_datetime_processor_factory(regexp, type_): rmatch = regexp.match # Even on python2.6 datetime.strptime is both slower than this code # and it does not support microseconds. + has_named_groups = bool(regexp.groupindex) + def process(value): if value is None: return None else: - return type_(*map(int, rmatch(value).groups(0))) + try: + m = rmatch(value) + except TypeError: + raise ValueError("Couldn't parse %s string '%r' " + "- value is not a string." % + (type_.__name__, value)) + if m is None: + raise ValueError("Couldn't parse %s string: " + "'%s'" % (type_.__name__, value)) + if has_named_groups: + groups = m.groupdict(0) + return type_(**dict(list(zip( + iter(groups.keys()), + list(map(int, iter(groups.values()))) + )))) + else: + return type_(*list(map(int, m.groups(0)))) return process -try: - from sqlalchemy.cprocessors import UnicodeResultProcessor, \ - DecimalResultProcessor, \ - to_float, to_str, int_to_boolean, \ - str_to_datetime, str_to_time, \ - str_to_date - def to_unicode_processor_factory(encoding, errors=None): - # this is cumbersome but it would be even more so on the C side - if errors is not None: - return UnicodeResultProcessor(encoding, errors).process - else: - return UnicodeResultProcessor(encoding).process - - def to_decimal_processor_factory(target_class, scale=10): - # Note that the scale argument is not taken into account for integer - # values in the C implementation while it is in the Python one. - # For example, the Python implementation might return - # Decimal('5.00000') whereas the C implementation will - # return Decimal('5'). These are equivalent of course. - return DecimalResultProcessor(target_class, "%%.%df" % scale).process +def boolean_to_int(value): + if value is None: + return None + else: + return int(bool(value)) -except ImportError: + +def py_fallback(): def to_unicode_processor_factory(encoding, errors=None): decoder = codecs.getdecoder(encoding) @@ -62,7 +70,22 @@ except ImportError: return decoder(value, errors)[0] return process - def to_decimal_processor_factory(target_class, scale=10): + def to_conditional_unicode_processor_factory(encoding, errors=None): + decoder = codecs.getdecoder(encoding) + + def process(value): + if value is None: + return None + elif isinstance(value, util.text_type): + return value + else: + # decoder returns a tuple: (value, len). 
Simply dropping the + # len part is safe: it is done that way in the normal + # 'xx'.decode(encoding) code path. + return decoder(value, errors)[0] + return process + + def to_decimal_processor_factory(target_class, scale): fstring = "%%.%df" % scale def process(value): @@ -88,14 +111,45 @@ except ImportError: if value is None: return None else: - return value and True or False + return bool(value) - DATETIME_RE = re.compile("(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?") - TIME_RE = re.compile("(\d+):(\d+):(\d+)(?:\.(\d+))?") - DATE_RE = re.compile("(\d+)-(\d+)-(\d+)") + DATETIME_RE = re.compile( + r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?") + TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") + DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)") str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE, datetime.datetime) str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time) str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date) + return locals() +try: + from sqlalchemy.cprocessors import UnicodeResultProcessor, \ + DecimalResultProcessor, \ + to_float, to_str, int_to_boolean, \ + str_to_datetime, str_to_time, \ + str_to_date + + def to_unicode_processor_factory(encoding, errors=None): + if errors is not None: + return UnicodeResultProcessor(encoding, errors).process + else: + return UnicodeResultProcessor(encoding).process + + def to_conditional_unicode_processor_factory(encoding, errors=None): + if errors is not None: + return UnicodeResultProcessor(encoding, errors).conditional_process + else: + return UnicodeResultProcessor(encoding).conditional_process + + def to_decimal_processor_factory(target_class, scale): + # Note that the scale argument is not taken into account for integer + # values in the C implementation while it is in the Python one. + # For example, the Python implementation might return + # Decimal('5.00000') whereas the C implementation will + # return Decimal('5'). These are equivalent of course. + return DecimalResultProcessor(target_class, "%%.%df" % scale).process + +except ImportError: + globals().update(py_fallback()) diff --git a/sqlalchemy/schema.py b/sqlalchemy/schema.py index 8ffb68a..9924a67 100644 --- a/sqlalchemy/schema.py +++ b/sqlalchemy/schema.py @@ -1,2386 +1,66 @@ # schema.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""The schema module provides the building blocks for database metadata. - -Each element within this module describes a database entity which can be -created and dropped, or is otherwise part of such an entity. Examples include -tables, columns, sequences, and indexes. - -All entities are subclasses of :class:`~sqlalchemy.schema.SchemaItem`, and as defined -in this module they are intended to be agnostic of any vendor-specific -constructs. - -A collection of entities are grouped into a unit called -:class:`~sqlalchemy.schema.MetaData`. MetaData serves as a logical grouping of schema -elements, and can also be associated with an actual database connection such -that operations involving the contained elements can contact the database as -needed. 
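# --- Illustrative sketch (editorial, not part of the patch) ----------------
# Back in the processors.py hunk above, str_to_datetime_processor_factory
# compiles one regexp and maps its groups onto a date/time constructor;
# the module-level str_to_date processor is built exactly that way.
# A minimal sketch of the same technique:
import re
import datetime
from sqlalchemy.processors import str_to_datetime_processor_factory

str_to_date = str_to_datetime_processor_factory(
    re.compile(r"(\d+)-(\d+)-(\d+)"), datetime.date)

assert str_to_date("2017-04-03") == datetime.date(2017, 4, 3)
assert str_to_date(None) is None    # None always passes through unchanged
# ----------------------------------------------------------------------------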
- -Two of the elements here also build upon their "syntactic" counterparts, which -are defined in :class:`~sqlalchemy.sql.expression.`, specifically -:class:`~sqlalchemy.schema.Table` and :class:`~sqlalchemy.schema.Column`. Since these objects -are part of the SQL expression language, they are usable as components in SQL -expressions. +"""Compatibility namespace for sqlalchemy.sql.schema and related. """ -import re, inspect -from sqlalchemy import exc, util, dialects -from sqlalchemy.sql import expression, visitors - -URL = None - -__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', - 'ForeignKeyConstraint', 'PrimaryKeyConstraint', 'CheckConstraint', - 'UniqueConstraint', 'DefaultGenerator', 'Constraint', 'MetaData', - 'ThreadLocalMetaData', 'SchemaVisitor', 'PassiveDefault', - 'DefaultClause', 'FetchedValue', 'ColumnDefault', 'DDL', - 'CreateTable', 'DropTable', 'CreateSequence', 'DropSequence', - 'AddConstraint', 'DropConstraint', - ] -__all__.sort() - -RETAIN_SCHEMA = util.symbol('retain_schema') - -class SchemaItem(visitors.Visitable): - """Base class for items that define a database schema.""" - - __visit_name__ = 'schema_item' - quote = None - - def _init_items(self, *args): - """Initialize the list of child items for this SchemaItem.""" - - for item in args: - if item is not None: - item._set_parent(self) - - def _set_parent(self, parent): - """Associate with this SchemaItem's parent object.""" - - raise NotImplementedError() - - def get_children(self, **kwargs): - """used to allow SchemaVisitor access""" - return [] - - def __repr__(self): - return "%s()" % self.__class__.__name__ - - @util.memoized_property - def info(self): - return {} - -def _get_table_key(name, schema): - if schema is None: - return name - else: - return schema + "." + name - -class Table(SchemaItem, expression.TableClause): - """Represent a table in a database. - - e.g.:: - - mytable = Table("mytable", metadata, - Column('mytable_id', Integer, primary_key=True), - Column('value', String(50)) - ) - - The Table object constructs a unique instance of itself based on its - name within the given MetaData object. Constructor - arguments are as follows: - - :param name: The name of this table as represented in the database. - - This property, along with the *schema*, indicates the *singleton - identity* of this table in relation to its parent :class:`MetaData`. - Additional calls to :class:`Table` with the same name, metadata, - and schema name will return the same :class:`Table` object. - - Names which contain no upper case characters - will be treated as case insensitive names, and will not be quoted - unless they are a reserved word. Names with any number of upper - case characters will be quoted and sent exactly. Note that this - behavior applies even for databases which standardize upper - case names as case insensitive such as Oracle. - - :param metadata: a :class:`MetaData` object which will contain this - table. The metadata is used as a point of association of this table - with other tables which are referenced via foreign key. It also - may be used to associate this table with a particular - :class:`~sqlalchemy.engine.base.Connectable`. - - :param \*args: Additional positional arguments are used primarily - to add the list of :class:`Column` objects contained within this table. - Similar to the style of a CREATE TABLE statement, other :class:`SchemaItem` - constructs may be added here, including :class:`PrimaryKeyConstraint`, - and :class:`ForeignKeyConstraint`. 
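# --- Illustrative sketch (editorial, not part of the patch) ----------------
# The (pre-1.0) Table docstring above describes singleton-by-name behavior
# within a MetaData.  A minimal sketch of the construction it documents:
from sqlalchemy import MetaData, Table, Column, Integer, String

metadata = MetaData()
mytable = Table("mytable", metadata,
                Column("mytable_id", Integer, primary_key=True),
                Column("value", String(50)))

# Re-calling Table with the same name and MetaData returns the same object:
assert Table("mytable", metadata) is mytable
# ----------------------------------------------------------------------------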
- - :param autoload: Defaults to False: the Columns for this table should be reflected - from the database. Usually there will be no Column objects in the - constructor if this property is set. - - :param autoload_with: If autoload==True, this is an optional Engine or Connection - instance to be used for the table reflection. If ``None``, the - underlying MetaData's bound connectable will be used. - - :param implicit_returning: True by default - indicates that - RETURNING can be used by default to fetch newly inserted primary key - values, for backends which support this. Note that - create_engine() also provides an implicit_returning flag. - - :param include_columns: A list of strings indicating a subset of columns to be loaded via - the ``autoload`` operation; table columns who aren't present in - this list will not be represented on the resulting ``Table`` - object. Defaults to ``None`` which indicates all columns should - be reflected. - - :param info: A dictionary which defaults to ``{}``. A space to store application - specific data. This must be a dictionary. - - :param mustexist: When ``True``, indicates that this Table must already - be present in the given :class:`MetaData`` collection. - - :param prefixes: - A list of strings to insert after CREATE in the CREATE TABLE - statement. They will be separated by spaces. - - :param quote: Force quoting of this table's name on or off, corresponding - to ``True`` or ``False``. When left at its default of ``None``, - the column identifier will be quoted according to whether the name is - case sensitive (identifiers with at least one upper case character are - treated as case sensitive), or if it's a reserved word. This flag - is only needed to force quoting of a reserved word which is not known - by the SQLAlchemy dialect. - - :param quote_schema: same as 'quote' but applies to the schema identifier. - - :param schema: The *schema name* for this table, which is required if the table - resides in a schema other than the default selected schema for the - engine's database connection. Defaults to ``None``. - - :param useexisting: When ``True``, indicates that if this Table is already - present in the given :class:`MetaData`, apply further arguments within - the constructor to the existing :class:`Table`. If this flag is not - set, an error is raised when the parameters of an existing :class:`Table` - are overwritten. - - """ - - __visit_name__ = 'table' - - ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop') - - def __new__(cls, *args, **kw): - if not args: - # python3k pickle seems to call this - return object.__new__(cls) - - try: - name, metadata, args = args[0], args[1], args[2:] - except IndexError: - raise TypeError("Table() takes at least two arguments") - - schema = kw.get('schema', None) - useexisting = kw.pop('useexisting', False) - mustexist = kw.pop('mustexist', False) - key = _get_table_key(name, schema) - if key in metadata.tables: - if not useexisting and bool(args): - raise exc.InvalidRequestError( - "Table '%s' is already defined for this MetaData instance. " - "Specify 'useexisting=True' to redefine options and " - "columns on an existing Table object." 
% key) - table = metadata.tables[key] - table._init_existing(*args, **kw) - return table - else: - if mustexist: - raise exc.InvalidRequestError( - "Table '%s' not defined" % (key)) - metadata.tables[key] = table = object.__new__(cls) - try: - table._init(name, metadata, *args, **kw) - return table - except: - metadata.tables.pop(key) - raise - - def __init__(self, *args, **kw): - # __init__ is overridden to prevent __new__ from - # calling the superclass constructor. - pass - - def _init(self, name, metadata, *args, **kwargs): - super(Table, self).__init__(name) - self.metadata = metadata - self.schema = kwargs.pop('schema', None) - self.indexes = set() - self.constraints = set() - self._columns = expression.ColumnCollection() - self._set_primary_key(PrimaryKeyConstraint()) - self._foreign_keys = util.OrderedSet() - self.ddl_listeners = util.defaultdict(list) - self.kwargs = {} - if self.schema is not None: - self.fullname = "%s.%s" % (self.schema, self.name) - else: - self.fullname = self.name - - autoload = kwargs.pop('autoload', False) - autoload_with = kwargs.pop('autoload_with', None) - include_columns = kwargs.pop('include_columns', None) - - self.implicit_returning = kwargs.pop('implicit_returning', True) - self.quote = kwargs.pop('quote', None) - self.quote_schema = kwargs.pop('quote_schema', None) - if 'info' in kwargs: - self.info = kwargs.pop('info') - - self._prefixes = kwargs.pop('prefixes', []) - - self._extra_kwargs(**kwargs) - - # load column definitions from the database if 'autoload' is defined - # we do it after the table is in the singleton dictionary to support - # circular foreign keys - if autoload: - if autoload_with: - autoload_with.reflecttable(self, include_columns=include_columns) - else: - _bind_or_error(metadata, msg="No engine is bound to this Table's MetaData. " - "Pass an engine to the Table via " - "autoload_with=, " - "or associate the MetaData with an engine via " - "metadata.bind=").\ - reflecttable(self, include_columns=include_columns) - - # initialize all the column, etc. objects. 
done after reflection to - # allow user-overrides - self._init_items(*args) - - def _init_existing(self, *args, **kwargs): - autoload = kwargs.pop('autoload', False) - autoload_with = kwargs.pop('autoload_with', None) - schema = kwargs.pop('schema', None) - if schema and schema != self.schema: - raise exc.ArgumentError( - "Can't change schema of existing table from '%s' to '%s'", - (self.schema, schema)) - - include_columns = kwargs.pop('include_columns', None) - if include_columns: - for c in self.c: - if c.name not in include_columns: - self.c.remove(c) - - for key in ('quote', 'quote_schema'): - if key in kwargs: - setattr(self, key, kwargs.pop(key)) - - if 'info' in kwargs: - self.info = kwargs.pop('info') - - self._extra_kwargs(**kwargs) - self._init_items(*args) - - def _extra_kwargs(self, **kwargs): - # validate remaining kwargs that they all specify DB prefixes - if len([k for k in kwargs - if not re.match(r'^(?:%s)_' % '|'.join(dialects.__all__), k)]): - raise TypeError( - "Invalid argument(s) for Table: %r" % kwargs.keys()) - self.kwargs.update(kwargs) - - def _set_primary_key(self, pk): - if getattr(self, '_primary_key', None) in self.constraints: - self.constraints.remove(self._primary_key) - self._primary_key = pk - self.constraints.add(pk) - - for c in pk.columns: - c.primary_key = True - - @util.memoized_property - def _autoincrement_column(self): - for col in self.primary_key: - if col.autoincrement and \ - isinstance(col.type, types.Integer) and \ - not col.foreign_keys and \ - isinstance(col.default, (type(None), Sequence)): - - return col - - @property - def key(self): - return _get_table_key(self.name, self.schema) - - @property - def primary_key(self): - return self._primary_key - - def __repr__(self): - return "Table(%s)" % ', '.join( - [repr(self.name)] + [repr(self.metadata)] + - [repr(x) for x in self.columns] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']]) - - def __str__(self): - return _get_table_key(self.description, self.schema) - - @property - def bind(self): - """Return the connectable associated with this Table.""" - - return self.metadata and self.metadata.bind or None - - def append_column(self, column): - """Append a ``Column`` to this ``Table``.""" - - column._set_parent(self) - - def append_constraint(self, constraint): - """Append a ``Constraint`` to this ``Table``.""" - - constraint._set_parent(self) - - def append_ddl_listener(self, event, listener): - """Append a DDL event listener to this ``Table``. - - The ``listener`` callable will be triggered when this ``Table`` is - created or dropped, either directly before or after the DDL is issued - to the database. The listener may modify the Table, but may not abort - the event itself. - - Arguments are: - - event - One of ``Table.ddl_events``; e.g. 'before-create', 'after-create', - 'before-drop' or 'after-drop'. - - listener - A callable, invoked with three positional arguments: - - event - The event currently being handled - target - The ``Table`` object being created or dropped - bind - The ``Connection`` bueing used for DDL execution. - - Listeners are added to the Table's ``ddl_listeners`` attribute. 
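# --- Illustrative sketch (editorial, not part of the patch) ----------------
# append_ddl_listener() and the ddl_listeners dict removed above were
# superseded by the sqlalchemy.event system; the modern spelling of an
# 'after-create' hook looks like this (table and index names are
# placeholders, not taken from the patch):
from sqlalchemy import MetaData, Table, Column, Integer, DDL, event

metadata = MetaData()
users = Table("users", metadata, Column("id", Integer, primary_key=True))

event.listen(
    users, "after_create",
    DDL("CREATE INDEX ix_users_demo ON users (id)"))
# ----------------------------------------------------------------------------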
- """ - - if event not in self.ddl_events: - raise LookupError(event) - self.ddl_listeners[event].append(listener) - - def _set_parent(self, metadata): - metadata.tables[_get_table_key(self.name, self.schema)] = self - self.metadata = metadata - - def get_children(self, column_collections=True, schema_visitor=False, **kwargs): - if not schema_visitor: - return expression.TableClause.get_children( - self, column_collections=column_collections, **kwargs) - else: - if column_collections: - return list(self.columns) - else: - return [] - - def exists(self, bind=None): - """Return True if this table exists.""" - - if bind is None: - bind = _bind_or_error(self) - - return bind.run_callable(bind.dialect.has_table, self.name, schema=self.schema) - - def create(self, bind=None, checkfirst=False): - """Issue a ``CREATE`` statement for this table. - - See also ``metadata.create_all()``. - """ - self.metadata.create_all(bind=bind, checkfirst=checkfirst, tables=[self]) - - def drop(self, bind=None, checkfirst=False): - """Issue a ``DROP`` statement for this table. - - See also ``metadata.drop_all()``. - """ - self.metadata.drop_all(bind=bind, checkfirst=checkfirst, tables=[self]) - - def tometadata(self, metadata, schema=RETAIN_SCHEMA): - """Return a copy of this ``Table`` associated with a different ``MetaData``.""" - - try: - if schema is RETAIN_SCHEMA: - schema = self.schema - key = _get_table_key(self.name, schema) - return metadata.tables[key] - except KeyError: - args = [] - for c in self.columns: - args.append(c.copy(schema=schema)) - for c in self.constraints: - args.append(c.copy(schema=schema)) - return Table(self.name, metadata, schema=schema, *args) - -class Column(SchemaItem, expression.ColumnClause): - """Represents a column in a database table.""" - - __visit_name__ = 'column' - - def __init__(self, *args, **kwargs): - """ - Construct a new ``Column`` object. - - :param name: The name of this column as represented in the database. - This argument may be the first positional argument, or specified - via keyword. - - Names which contain no upper case characters - will be treated as case insensitive names, and will not be quoted - unless they are a reserved word. Names with any number of upper - case characters will be quoted and sent exactly. Note that this - behavior applies even for databases which standardize upper - case names as case insensitive such as Oracle. - - The name field may be omitted at construction time and applied - later, at any time before the Column is associated with a - :class:`Table`. This is to support convenient - usage within the :mod:`~sqlalchemy.ext.declarative` extension. - - :param type\_: The column's type, indicated using an instance which - subclasses :class:`~sqlalchemy.types.AbstractType`. If no arguments - are required for the type, the class of the type can be sent - as well, e.g.:: - - # use a type with arguments - Column('data', String(50)) - - # use no arguments - Column('level', Integer) - - The ``type`` argument may be the second positional argument - or specified by keyword. - - If this column also contains a :class:`ForeignKey`, - the type argument may be left as ``None`` in which case the - type assigned will be that of the referenced column. - - :param \*args: Additional positional arguments include various - :class:`SchemaItem` derived constructs which will be applied - as options to the column. These include instances of - :class:`Constraint`, :class:`ForeignKey`, :class:`ColumnDefault`, - and :class:`Sequence`. 
In some cases an equivalent keyword - argument is available such as ``server_default``, ``default`` - and ``unique``. - - :param autoincrement: This flag may be set to ``False`` to - indicate an integer primary key column that should not be - considered to be the "autoincrement" column, that is - the integer primary key column which generates values - implicitly upon INSERT and whose value is usually returned - via the DBAPI cursor.lastrowid attribute. It defaults - to ``True`` to satisfy the common use case of a table - with a single integer primary key column. If the table - has a composite primary key consisting of more than one - integer column, set this flag to True only on the - column that should be considered "autoincrement". - - The setting *only* has an effect for columns which are: - - * Integer derived (i.e. INT, SMALLINT, BIGINT) - - * Part of the primary key - - * Are not referenced by any foreign keys - - * have no server side or client side defaults (with the exception - of Postgresql SERIAL). - - The setting has these two effects on columns that meet the - above criteria: - - * DDL issued for the column will include database-specific - keywords intended to signify this column as an - "autoincrement" column, such as AUTO INCREMENT on MySQL, - SERIAL on Postgresql, and IDENTITY on MS-SQL. It does - *not* issue AUTOINCREMENT for SQLite since this is a - special SQLite flag that is not required for autoincrementing - behavior. See the SQLite dialect documentation for - information on SQLite's AUTOINCREMENT. - - * The column will be considered to be available as - cursor.lastrowid or equivalent, for those dialects which - "post fetch" newly inserted identifiers after a row has - been inserted (SQLite, MySQL, MS-SQL). It does not have - any effect in this regard for databases that use sequences - to generate primary key identifiers (i.e. Firebird, Postgresql, - Oracle). - - :param default: A scalar, Python callable, or - :class:`~sqlalchemy.sql.expression.ClauseElement` representing the - *default value* for this column, which will be invoked upon insert - if this column is otherwise not specified in the VALUES clause of - the insert. This is a shortcut to using :class:`ColumnDefault` as - a positional argument. - - Contrast this argument to ``server_default`` which creates a - default generator on the database side. - - :param key: An optional string identifier which will identify this - ``Column`` object on the :class:`Table`. When a key is provided, - this is the only identifier referencing the ``Column`` within the - application, including ORM attribute mapping; the ``name`` field - is used only when rendering SQL. - - :param index: When ``True``, indicates that the column is indexed. - This is a shortcut for using a :class:`Index` construct on the - table. To specify indexes with explicit names or indexes that - contain multiple columns, use the :class:`Index` construct - instead. - - :param info: A dictionary which defaults to ``{}``. A space to store - application specific data. This must be a dictionary. - - :param nullable: If set to the default of ``True``, indicates the - column will be rendered as allowing NULL, else it's rendered as - NOT NULL. This parameter is only used when issuing CREATE TABLE - statements. 
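# --- Illustrative sketch (editorial, not part of the patch) ----------------
# The Column parameters documented above (default, key, nullable, ...)
# combine as in this minimal sketch; the callable default is evaluated on
# each INSERT that omits the column.  Names here are arbitrary.
import datetime
from sqlalchemy import MetaData, Table, Column, Integer, DateTime, String

metadata = MetaData()
entries = Table(
    "entries", metadata,
    Column("id", Integer, primary_key=True),   # autoincrement by default
    Column("created", DateTime, default=datetime.datetime.utcnow),
    Column("title", String(100), nullable=False),
)
# ----------------------------------------------------------------------------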
-
-        :param onupdate: A scalar, Python callable, or
-          :class:`~sqlalchemy.sql.expression.ClauseElement` representing a
-          default value to be applied to the column within UPDATE
-          statements, which wil be invoked upon update if this column is not
-          present in the SET clause of the update. This is a shortcut to
-          using :class:`ColumnDefault` as a positional argument with
-          ``for_update=True``.
-
-        :param primary_key: If ``True``, marks this column as a primary key
-          column. Multiple columns can have this flag set to specify
-          composite primary keys. As an alternative, the primary key of a
-          :class:`Table` can be specified via an explicit
-          :class:`PrimaryKeyConstraint` object.
-
-        :param server_default: A :class:`FetchedValue` instance, str, Unicode
-          or :func:`~sqlalchemy.sql.expression.text` construct representing
-          the DDL DEFAULT value for the column.
-
-          String types will be emitted as-is, surrounded by single quotes::
-
-            Column('x', Text, server_default="val")
-
-            x TEXT DEFAULT 'val'
-
-          A :func:`~sqlalchemy.sql.expression.text` expression will be
-          rendered as-is, without quotes::
-
-            Column('y', DateTime, server_default=text('NOW()'))
-
-            y DATETIME DEFAULT NOW()
-
-          Strings and text() will be converted into a :class:`DefaultClause`
-          object upon initialization.
-
-          Use :class:`FetchedValue` to indicate that an already-existing
-          column will generate a default value on the database side which
-          will be available to SQLAlchemy for post-fetch after inserts. This
-          construct does not specify any DDL and the implementation is left
-          to the database, such as via a trigger.
-
-        :param server_onupdate: A :class:`FetchedValue` instance
-          representing a database-side default generation function. This
-          indicates to SQLAlchemy that a newly generated value will be
-          available after updates. This construct does not specify any DDL
-          and the implementation is left to the database, such as via a
-          trigger.
-
-        :param quote: Force quoting of this column's name on or off,
-          corresponding to ``True`` or ``False``. When left at its default
-          of ``None``, the column identifier will be quoted according to
-          whether the name is case sensitive (identifiers with at least one
-          upper case character are treated as case sensitive), or if it's a
-          reserved word. This flag is only needed to force quoting of a
-          reserved word which is not known by the SQLAlchemy dialect.
-
-        :param unique: When ``True``, indicates that this column contains a
-          unique constraint, or if ``index`` is ``True`` as well, indicates
-          that the :class:`Index` should be created with the unique flag.
-          To specify multiple columns in the constraint/index or to specify
-          an explicit name, use the :class:`UniqueConstraint` or
-          :class:`Index` constructs explicitly.
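# --- Illustrative sketch (editorial, not part of the patch) ----------------
# server_default and unique/index, documented above, affect only the DDL
# emitted for the table.  A sketch of both server_default forms (plain
# string vs. text()); table and column names are arbitrary:
from sqlalchemy import MetaData, Table, Column, Integer, Text, DateTime, text

metadata = MetaData()
docs = Table(
    "docs", metadata,
    Column("id", Integer, primary_key=True),
    Column("x", Text, server_default="val"),              # ... DEFAULT 'val'
    Column("y", DateTime, server_default=text("NOW()")),  # ... DEFAULT NOW()
    Column("slug", Text, unique=True),                    # UNIQUE constraint
)
# ----------------------------------------------------------------------------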
- - """ - - name = kwargs.pop('name', None) - type_ = kwargs.pop('type_', None) - args = list(args) - if args: - if isinstance(args[0], basestring): - if name is not None: - raise exc.ArgumentError( - "May not pass name positionally and as a keyword.") - name = args.pop(0) - if args: - coltype = args[0] - - if (isinstance(coltype, types.AbstractType) or - (isinstance(coltype, type) and - issubclass(coltype, types.AbstractType))): - if type_ is not None: - raise exc.ArgumentError( - "May not pass type_ positionally and as a keyword.") - type_ = args.pop(0) - - no_type = type_ is None - - super(Column, self).__init__(name, None, type_) - self.key = kwargs.pop('key', name) - self.primary_key = kwargs.pop('primary_key', False) - self.nullable = kwargs.pop('nullable', not self.primary_key) - self.default = kwargs.pop('default', None) - self.server_default = kwargs.pop('server_default', None) - self.server_onupdate = kwargs.pop('server_onupdate', None) - self.index = kwargs.pop('index', None) - self.unique = kwargs.pop('unique', None) - self.quote = kwargs.pop('quote', None) - self.onupdate = kwargs.pop('onupdate', None) - self.autoincrement = kwargs.pop('autoincrement', True) - self.constraints = set() - self.foreign_keys = util.OrderedSet() - self._table_events = set() - - # check if this Column is proxying another column - if '_proxies' in kwargs: - self.proxies = kwargs.pop('_proxies') - # otherwise, add DDL-related events - elif isinstance(self.type, types.SchemaType): - self.type._set_parent(self) - - if self.default is not None: - if isinstance(self.default, (ColumnDefault, Sequence)): - args.append(self.default) - else: - args.append(ColumnDefault(self.default)) - - if self.server_default is not None: - if isinstance(self.server_default, FetchedValue): - args.append(self.server_default) - else: - args.append(DefaultClause(self.server_default)) - - if self.onupdate is not None: - if isinstance(self.onupdate, (ColumnDefault, Sequence)): - args.append(self.onupdate) - else: - args.append(ColumnDefault(self.onupdate, for_update=True)) - - if self.server_onupdate is not None: - if isinstance(self.server_onupdate, FetchedValue): - args.append(self.server_default) - else: - args.append(DefaultClause(self.server_onupdate, - for_update=True)) - self._init_items(*args) - - if not self.foreign_keys and no_type: - raise exc.ArgumentError("'type' is required on Column objects " - "which have no foreign keys.") - util.set_creation_order(self) - - if 'info' in kwargs: - self.info = kwargs.pop('info') - - if kwargs: - raise exc.ArgumentError( - "Unknown arguments passed to Column: " + repr(kwargs.keys())) - - def __str__(self): - if self.name is None: - return "(no name)" - elif self.table is not None: - if self.table.named_with_column: - return (self.table.description + "." 
+ self.description) - else: - return self.description - else: - return self.description - - def references(self, column): - """Return True if this Column references the given column via foreign key.""" - for fk in self.foreign_keys: - if fk.references(column.table): - return True - else: - return False - - def append_foreign_key(self, fk): - fk._set_parent(self) - - def __repr__(self): - kwarg = [] - if self.key != self.name: - kwarg.append('key') - if self.primary_key: - kwarg.append('primary_key') - if not self.nullable: - kwarg.append('nullable') - if self.onupdate: - kwarg.append('onupdate') - if self.default: - kwarg.append('default') - if self.server_default: - kwarg.append('server_default') - return "Column(%s)" % ', '.join( - [repr(self.name)] + [repr(self.type)] + - [repr(x) for x in self.foreign_keys if x is not None] + - [repr(x) for x in self.constraints] + - [(self.table is not None and "table=<%s>" % self.table.description or "")] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]) - - def _set_parent(self, table): - if self.name is None: - raise exc.ArgumentError( - "Column must be constructed with a name or assign .name " - "before adding to a Table.") - if self.key is None: - self.key = self.name - - if getattr(self, 'table', None) is not None: - raise exc.ArgumentError("this Column already has a table!") - - if self.key in table._columns: - col = table._columns.get(self.key) - for fk in col.foreign_keys: - col.foreign_keys.remove(fk) - table.foreign_keys.remove(fk) - table.constraints.remove(fk.constraint) - - table._columns.replace(self) - - if self.primary_key: - table.primary_key._replace(self) - elif self.key in table.primary_key: - raise exc.ArgumentError( - "Trying to redefine primary-key column '%s' as a " - "non-primary-key column on table '%s'" % ( - self.key, table.fullname)) - self.table = table - - if self.index: - if isinstance(self.index, basestring): - raise exc.ArgumentError( - "The 'index' keyword argument on Column is boolean only. " - "To create indexes with a specific name, create an " - "explicit Index object external to the Table.") - Index('ix_%s' % self._label, self, unique=self.unique) - elif self.unique: - if isinstance(self.unique, basestring): - raise exc.ArgumentError( - "The 'unique' keyword argument on Column is boolean only. " - "To create unique constraints or indexes with a specific " - "name, append an explicit UniqueConstraint to the Table's " - "list of elements, or create an explicit Index object " - "external to the Table.") - table.append_constraint(UniqueConstraint(self.key)) - - for fn in self._table_events: - fn(table, self) - del self._table_events - - def _on_table_attach(self, fn): - if self.table is not None: - fn(self.table, self) - else: - self._table_events.add(fn) - - def copy(self, **kw): - """Create a copy of this ``Column``, unitialized. - - This is used in ``Table.tometadata``. 
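# --- Illustrative sketch (editorial, not part of the patch) ----------------
# Column.copy(), above, is what Table.tometadata() uses to clone a table
# into another MetaData collection, optionally under a new schema.  A
# minimal sketch; the "reporting" schema name is arbitrary:
from sqlalchemy import MetaData, Table, Column, Integer

m1 = MetaData()
t1 = Table("t", m1, Column("id", Integer, primary_key=True))

m2 = MetaData()
t2 = t1.tometadata(m2, schema="reporting")
assert t2.schema == "reporting" and t2 is not t1
# ----------------------------------------------------------------------------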
- - """ - - # Constraint objects plus non-constraint-bound ForeignKey objects - args = \ - [c.copy(**kw) for c in self.constraints] + \ - [c.copy(**kw) for c in self.foreign_keys if not c.constraint] - - c = Column( - name=self.name, - type_=self.type, - key = self.key, - primary_key = self.primary_key, - nullable = self.nullable, - quote=self.quote, - index=self.index, - autoincrement=self.autoincrement, - default=self.default, - server_default=self.server_default, - onupdate=self.onupdate, - server_onupdate=self.server_onupdate, - *args - ) - if hasattr(self, '_table_events'): - c._table_events = list(self._table_events) - return c - - def _make_proxy(self, selectable, name=None): - """Create a *proxy* for this column. - - This is a copy of this ``Column`` referenced by a different parent - (such as an alias or select statement). The column should - be used only in select scenarios, as its full DDL/default - information is not transferred. - - """ - fk = [ForeignKey(f.column) for f in self.foreign_keys] - c = Column( - name or self.name, - self.type, - key = name or self.key, - primary_key = self.primary_key, - nullable = self.nullable, - quote=self.quote, _proxies=[self], *fk) - c.table = selectable - selectable.columns.add(c) - if self.primary_key: - selectable.primary_key.add(c) - for fn in c._table_events: - fn(selectable, c) - del c._table_events - return c - - def get_children(self, schema_visitor=False, **kwargs): - if schema_visitor: - return [x for x in (self.default, self.onupdate) if x is not None] + \ - list(self.foreign_keys) + list(self.constraints) - else: - return expression.ColumnClause.get_children(self, **kwargs) - - -class ForeignKey(SchemaItem): - """Defines a dependency between two columns. - - ``ForeignKey`` is specified as an argument to a :class:`Column` object, - e.g.:: - - t = Table("remote_table", metadata, - Column("remote_id", ForeignKey("main_table.id")) - ) - - Note that ``ForeignKey`` is only a marker object that defines - a dependency between two columns. The actual constraint - is in all cases represented by the :class:`ForeignKeyConstraint` - object. This object will be generated automatically when - a ``ForeignKey`` is associated with a :class:`Column` which - in turn is associated with a :class:`Table`. Conversely, - when :class:`ForeignKeyConstraint` is applied to a :class:`Table`, - ``ForeignKey`` markers are automatically generated to be - present on each associated :class:`Column`, which are also - associated with the constraint object. - - Note that you cannot define a "composite" foreign key constraint, - that is a constraint between a grouping of multiple parent/child - columns, using ``ForeignKey`` objects. To define this grouping, - the :class:`ForeignKeyConstraint` object must be used, and applied - to the :class:`Table`. The associated ``ForeignKey`` objects - are created automatically. - - The ``ForeignKey`` objects associated with an individual - :class:`Column` object are available in the `foreign_keys` collection - of that column. - - Further examples of foreign key configuration are in - :ref:`metadata_foreignkeys`. - - """ - - __visit_name__ = 'foreign_key' - - def __init__(self, column, _constraint=None, use_alter=False, name=None, - onupdate=None, ondelete=None, deferrable=None, - initially=None, link_to_name=False): - """ - Construct a column-level FOREIGN KEY. 
- - The :class:`ForeignKey` object when constructed generates a - :class:`ForeignKeyConstraint` which is associated with the parent - :class:`Table` object's collection of constraints. - - :param column: A single target column for the key relationship. A - :class:`Column` object or a column name as a string: - ``tablename.columnkey`` or ``schema.tablename.columnkey``. - ``columnkey`` is the ``key`` which has been assigned to the column - (defaults to the column name itself), unless ``link_to_name`` is - ``True`` in which case the rendered name of the column is used. - - :param name: Optional string. An in-database name for the key if - `constraint` is not provided. - - :param onupdate: Optional string. If set, emit ON UPDATE when - issuing DDL for this constraint. Typical values include CASCADE, - SET NULL and RESTRICT. - - :param ondelete: Optional string. If set, emit ON DELETE when - issuing DDL for this constraint. Typical values include CASCADE, - SET NULL and RESTRICT. - - :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT - DEFERRABLE when issuing DDL for this constraint. - - :param initially: Optional string. If set, emit INITIALLY when - issuing DDL for this constraint. - - :param link_to_name: if True, the string name given in ``column`` is - the rendered name of the referenced column, not its locally - assigned ``key``. - - :param use_alter: passed to the underlying - :class:`ForeignKeyConstraint` to indicate the constraint should be - generated/dropped externally from the CREATE TABLE / DROP TABLE - statement. See that class's constructor for details. - - """ - - self._colspec = column - - # the linked ForeignKeyConstraint. - # ForeignKey will create this when parent Column - # is attached to a Table, *or* ForeignKeyConstraint - # object passes itself in when creating ForeignKey - # markers. - self.constraint = _constraint - - - self.use_alter = use_alter - self.name = name - self.onupdate = onupdate - self.ondelete = ondelete - self.deferrable = deferrable - self.initially = initially - self.link_to_name = link_to_name - - def __repr__(self): - return "ForeignKey(%r)" % self._get_colspec() - - def copy(self, schema=None): - """Produce a copy of this ForeignKey object.""" - - return ForeignKey( - self._get_colspec(schema=schema), - use_alter=self.use_alter, - name=self.name, - onupdate=self.onupdate, - ondelete=self.ondelete, - deferrable=self.deferrable, - initially=self.initially, - link_to_name=self.link_to_name - ) - - def _get_colspec(self, schema=None): - if schema: - return schema + "." + self.column.table.name + "." + self.column.key - elif isinstance(self._colspec, basestring): - return self._colspec - elif hasattr(self._colspec, '__clause_element__'): - _column = self._colspec.__clause_element__() - else: - _column = self._colspec - - return "%s.%s" % (_column.table.fullname, _column.key) - - target_fullname = property(_get_colspec) - - def references(self, table): - """Return True if the given table is referenced by this ForeignKey.""" - return table.corresponding_column(self.column) is not None - - def get_referent(self, table): - """Return the column in the given table referenced by this ForeignKey. - - Returns None if this ``ForeignKey`` does not reference the given table.
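
A short sketch of the parameters documented above, plus the deferred resolution that ``ForeignKey`` exists to provide (names invented; this is illustration, not part of the patch):

```python
from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey

metadata = MetaData()

# The referenced table need not exist yet: the string form
# "tablename.columnkey" is resolved lazily against the MetaData.
child = Table(
    'child', metadata,
    Column('id', Integer, primary_key=True),
    Column('parent_id', Integer,
           ForeignKey('parent.id', onupdate='CASCADE',
                      ondelete='SET NULL')))

parent = Table(
    'parent', metadata,
    Column('id', Integer, primary_key=True))

# First access of .column performs the lookup described below.
fk = list(child.c.parent_id.foreign_keys)[0]
assert fk.column is parent.c.id
```
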
- - """ - - return table.corresponding_column(self.column) - - @util.memoized_property - def column(self): - # ForeignKey inits its remote column as late as possible, so tables - # can be defined without dependencies - if isinstance(self._colspec, basestring): - # locate the parent table this foreign key is attached to. we - # use the "original" column which our parent column represents - # (its a list of columns/other ColumnElements if the parent - # table is a UNION) - for c in self.parent.base_columns: - if isinstance(c, Column): - parenttable = c.table - break - else: - raise exc.ArgumentError( - "Parent column '%s' does not descend from a " - "table-attached Column" % str(self.parent)) - - m = self._colspec.split('.') - - if m is None: - raise exc.ArgumentError( - "Invalid foreign key column specification: %s" % - self._colspec) - - # A FK between column 'bar' and table 'foo' can be - # specified as 'foo', 'foo.bar', 'dbo.foo.bar', - # 'otherdb.dbo.foo.bar'. Once we have the column name and - # the table name, treat everything else as the schema - # name. Some databases (e.g. Sybase) support - # inter-database foreign keys. See tickets#1341 and -- - # indirectly related -- Ticket #594. This assumes that '.' - # will never appear *within* any component of the FK. - - (schema, tname, colname) = (None, None, None) - if (len(m) == 1): - tname = m.pop() - else: - colname = m.pop() - tname = m.pop() - - if (len(m) > 0): - schema = '.'.join(m) - - if _get_table_key(tname, schema) not in parenttable.metadata: - raise exc.NoReferencedTableError( - "Could not find table '%s' with which to generate a " - "foreign key" % tname) - table = Table(tname, parenttable.metadata, - mustexist=True, schema=schema) - - _column = None - if colname is None: - # colname is None in the case that ForeignKey argument - # was specified as table name only, in which case we - # match the column name to the same column on the - # parent. - key = self.parent - _column = table.c.get(self.parent.key, None) - elif self.link_to_name: - key = colname - for c in table.c: - if c.name == colname: - _column = c - else: - key = colname - _column = table.c.get(colname, None) - - if _column is None: - raise exc.NoReferencedColumnError( - "Could not create ForeignKey '%s' on table '%s': " - "table '%s' has no column named '%s'" % ( - self._colspec, parenttable.name, table.name, key)) - - elif hasattr(self._colspec, '__clause_element__'): - _column = self._colspec.__clause_element__() - else: - _column = self._colspec - - # propagate TypeEngine to parent if it didn't have one - if isinstance(self.parent.type, types.NullType): - self.parent.type = _column.type - return _column - - def _set_parent(self, column): - if hasattr(self, 'parent'): - if self.parent is column: - return - raise exc.InvalidRequestError("This ForeignKey already has a parent !") - self.parent = column - self.parent.foreign_keys.add(self) - self.parent._on_table_attach(self._set_table) - - def _set_table(self, table, column): - # standalone ForeignKey - create ForeignKeyConstraint - # on the hosting Table when attached to the Table. 
- if self.constraint is None and isinstance(table, Table): - self.constraint = ForeignKeyConstraint( - [], [], use_alter=self.use_alter, name=self.name, - onupdate=self.onupdate, ondelete=self.ondelete, - deferrable=self.deferrable, initially=self.initially, - ) - self.constraint._elements[self.parent] = self - self.constraint._set_parent(table) - table.foreign_keys.add(self) - -class DefaultGenerator(SchemaItem): - """Base class for column *default* values.""" - - __visit_name__ = 'default_generator' - - is_sequence = False - - def __init__(self, for_update=False): - self.for_update = for_update - - def _set_parent(self, column): - self.column = column - if self.for_update: - self.column.onupdate = self - else: - self.column.default = self - - def execute(self, bind=None, **kwargs): - if bind is None: - bind = _bind_or_error(self) - return bind._execute_default(self, **kwargs) - - @property - def bind(self): - """Return the connectable associated with this default.""" - if getattr(self, 'column', None) is not None: - return self.column.table.bind - else: - return None - - def __repr__(self): - return "DefaultGenerator()" - - -class ColumnDefault(DefaultGenerator): - """A plain default value on a column. - - This could correspond to a constant, a callable function, or a SQL clause. - """ - - def __init__(self, arg, **kwargs): - super(ColumnDefault, self).__init__(**kwargs) - if isinstance(arg, FetchedValue): - raise exc.ArgumentError( - "ColumnDefault may not be a server-side default type.") - if util.callable(arg): - arg = self._maybe_wrap_callable(arg) - self.arg = arg - - @util.memoized_property - def is_callable(self): - return util.callable(self.arg) - - @util.memoized_property - def is_clause_element(self): - return isinstance(self.arg, expression.ClauseElement) - - @util.memoized_property - def is_scalar(self): - return not self.is_callable and not self.is_clause_element and not self.is_sequence - - def _maybe_wrap_callable(self, fn): - """Backward compat: Wrap callables that don't accept a context.""" - - if inspect.isfunction(fn): - inspectable = fn - elif inspect.isclass(fn): - inspectable = fn.__init__ - elif hasattr(fn, '__call__'): - inspectable = fn.__call__ - else: - # probably not inspectable, try anyways. 
- inspectable = fn - try: - argspec = inspect.getargspec(inspectable) - except TypeError: - return lambda ctx: fn() - - positionals = len(argspec[0]) - - # Py3K compat - no unbound methods - if inspect.ismethod(inspectable) or inspect.isclass(fn): - positionals -= 1 - - if positionals == 0: - return lambda ctx: fn() - - defaulted = argspec[3] is not None and len(argspec[3]) or 0 - if positionals - defaulted > 1: - raise exc.ArgumentError( - "ColumnDefault Python function takes zero or one " - "positional arguments") - return fn - - def _visit_name(self): - if self.for_update: - return "column_onupdate" - else: - return "column_default" - __visit_name__ = property(_visit_name) - - def __repr__(self): - return "ColumnDefault(%r)" % self.arg - -class Sequence(DefaultGenerator): - """Represents a named database sequence.""" - - __visit_name__ = 'sequence' - - is_sequence = True - - def __init__(self, name, start=None, increment=None, schema=None, - optional=False, quote=None, metadata=None, for_update=False): - super(Sequence, self).__init__(for_update=for_update) - self.name = name - self.start = start - self.increment = increment - self.optional = optional - self.quote = quote - self.schema = schema - self.metadata = metadata - - @util.memoized_property - def is_callable(self): - return False - - @util.memoized_property - def is_clause_element(self): - return False - - def __repr__(self): - return "Sequence(%s)" % ', '.join( - [repr(self.name)] + - ["%s=%s" % (k, repr(getattr(self, k))) - for k in ['start', 'increment', 'optional']]) - - def _set_parent(self, column): - super(Sequence, self)._set_parent(column) - column._on_table_attach(self._set_table) - - def _set_table(self, table, column): - self.metadata = table.metadata - - @property - def bind(self): - if self.metadata: - return self.metadata.bind - else: - return None - - def create(self, bind=None, checkfirst=True): - """Creates this sequence in the database.""" - - if bind is None: - bind = _bind_or_error(self) - bind.create(self, checkfirst=checkfirst) - - def drop(self, bind=None, checkfirst=True): - """Drops this sequence from the database.""" - - if bind is None: - bind = _bind_or_error(self) - bind.drop(self, checkfirst=checkfirst) - - -class FetchedValue(object): - """A default that takes effect on the database side.""" - - def __init__(self, for_update=False): - self.for_update = for_update - - def _set_parent(self, column): - self.column = column - if self.for_update: - self.column.server_onupdate = self - else: - self.column.server_default = self - - def __repr__(self): - return 'FetchedValue(for_update=%r)' % self.for_update - - -class DefaultClause(FetchedValue): - """A DDL-specified DEFAULT column value.""" - - def __init__(self, arg, for_update=False): - util.assert_arg_type(arg, (basestring, - expression.ClauseElement, - expression._TextClause), 'arg') - super(DefaultClause, self).__init__(for_update) - self.arg = arg - - def __repr__(self): - return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update) - -class PassiveDefault(DefaultClause): - def __init__(self, *arg, **kw): - util.warn_deprecated("PassiveDefault is deprecated. Use DefaultClause.") - DefaultClause.__init__(self, *arg, **kw) - -class Constraint(SchemaItem): - """A table-level SQL constraint.""" - - __visit_name__ = 'constraint' - - def __init__(self, name=None, deferrable=None, initially=None, - _create_rule=None): - """Create a SQL constraint. - - name - Optional, the in-database name of this ``Constraint``. - - deferrable - Optional bool. 
If set, emit DEFERRABLE or NOT DEFERRABLE when - issuing DDL for this constraint. - - initially - Optional string. If set, emit INITIALLY when issuing DDL - for this constraint. - - _create_rule - a callable which is passed the DDLCompiler object during - compilation. Returns True or False to signal inline generation of - this Constraint. - - The AddConstraint and DropConstraint DDL constructs provide - DDLElement's more comprehensive "conditional DDL" approach, which is - passed a database connection when DDL is being issued. _create_rule - is instead called during any CREATE TABLE compilation, where there - may not be any transaction/connection in progress. However, it - allows conditional compilation of the constraint even for backends - which do not support addition of constraints through ALTER TABLE, - which currently includes SQLite. - - _create_rule is used by some types to create constraints. - Currently, its call signature is subject to change at any time. - - """ - - self.name = name - self.deferrable = deferrable - self.initially = initially - self._create_rule = _create_rule - - @property - def table(self): - try: - if isinstance(self.parent, Table): - return self.parent - except AttributeError: - pass - raise exc.InvalidRequestError("This constraint is not bound to a table. Did you mean to call table.append_constraint(constraint)?") - - def _set_parent(self, parent): - self.parent = parent - parent.constraints.add(self) - - def copy(self, **kw): - raise NotImplementedError() - -class ColumnCollectionConstraint(Constraint): - """A constraint that proxies a ColumnCollection.""" - - def __init__(self, *columns, **kw): - """ - \*columns - A sequence of column names or Column objects. - - name - Optional, the in-database name of this constraint. - - deferrable - Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when - issuing DDL for this constraint. - - initially - Optional string. If set, emit INITIALLY when issuing DDL - for this constraint. - - """ - super(ColumnCollectionConstraint, self).__init__(**kw) - self.columns = expression.ColumnCollection() - self._pending_colargs = [_to_schema_column_or_string(c) for c in columns] - if self._pending_colargs and \ - isinstance(self._pending_colargs[0], Column) and \ - self._pending_colargs[0].table is not None: - self._set_parent(self._pending_colargs[0].table) - - def _set_parent(self, table): - super(ColumnCollectionConstraint, self)._set_parent(table) - for col in self._pending_colargs: - if isinstance(col, basestring): - col = table.c[col] - self.columns.add(col) - - def __contains__(self, x): - return x in self.columns - - def copy(self, **kw): - return self.__class__(name=self.name, deferrable=self.deferrable, - initially=self.initially, *self.columns.keys()) - - def contains_column(self, col): - return self.columns.contains_column(col) - - def __iter__(self): - return iter(self.columns) - - def __len__(self): - return len(self.columns) - - -class CheckConstraint(Constraint): - """A table- or column-level CHECK constraint. - - Can be included in the definition of a Table or Column. - """ - - def __init__(self, sqltext, name=None, deferrable=None, - initially=None, table=None, _create_rule=None): - """Construct a CHECK constraint. - - sqltext - A string containing the constraint definition, which will be used - verbatim, or a SQL expression construct. - - name - Optional, the in-database name of the constraint. - - deferrable - Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when - issuing DDL for this constraint.
- - initially - Optional string. If set, emit INITIALLY when issuing DDL - for this constraint. - - """ - - super(CheckConstraint, self).__init__(name, deferrable, initially, _create_rule) - self.sqltext = expression._literal_as_text(sqltext) - if table is not None: - self._set_parent(table) - - def __visit_name__(self): - if isinstance(self.parent, Table): - return "check_constraint" - else: - return "column_check_constraint" - __visit_name__ = property(__visit_name__) - - def copy(self, **kw): - return CheckConstraint(self.sqltext, name=self.name) - -class ForeignKeyConstraint(Constraint): - """A table-level FOREIGN KEY constraint. - - Defines a single column or composite FOREIGN KEY ... REFERENCES - constraint. For a no-frills, single column foreign key, adding a - :class:`ForeignKey` to the definition of a :class:`Column` is a shorthand - equivalent for an unnamed, single column :class:`ForeignKeyConstraint`. - - Examples of foreign key configuration are in :ref:`metadata_foreignkeys`. - - """ - __visit_name__ = 'foreign_key_constraint' - - def __init__(self, columns, refcolumns, name=None, onupdate=None, - ondelete=None, deferrable=None, initially=None, use_alter=False, - link_to_name=False, table=None): - """Construct a composite-capable FOREIGN KEY. - - :param columns: A sequence of local column names. The named columns - must be defined and present in the parent Table. The names should - match the ``key`` given to each column (defaults to the name) unless - ``link_to_name`` is True. - - :param refcolumns: A sequence of foreign column names or Column - objects. The columns must all be located within the same Table. - - :param name: Optional, the in-database name of the key. - - :param onupdate: Optional string. If set, emit ON UPDATE when - issuing DDL for this constraint. Typical values include CASCADE, - SET NULL and RESTRICT. - - :param ondelete: Optional string. If set, emit ON DELETE when - issuing DDL for this constraint. Typical values include CASCADE, - SET NULL and RESTRICT. - - :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT - DEFERRABLE when issuing DDL for this constraint. - - :param initially: Optional string. If set, emit INITIALLY when - issuing DDL for this constraint. - - :param link_to_name: if True, the string names given in ``refcolumns`` are - the rendered names of the referenced columns, not their locally assigned - ``key``. - - :param use_alter: If True, do not emit the DDL for this constraint as - part of the CREATE TABLE definition. Instead, generate it via an - ALTER TABLE statement issued after the full collection of tables - has been created, and drop it via an ALTER TABLE statement before - the full collection of tables is dropped. This is shorthand for the - usage of :class:`AddConstraint` and :class:`DropConstraint` applied - as "after-create" and "before-drop" events on the MetaData object. - This is normally used to generate/drop constraints on objects that - are mutually dependent on each other.
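
The composite case that a lone ``ForeignKey`` marker cannot express, as a hedged sketch (invented names; ``use_alter`` requires a ``name``, per the constructor below):

```python
from sqlalchemy import (MetaData, Table, Column, Integer,
                        ForeignKeyConstraint)

metadata = MetaData()

node = Table(
    'node', metadata,
    Column('node_id', Integer, primary_key=True),
    Column('version_id', Integer, primary_key=True))

# A grouping of columns must be declared at the table level;
# matching ForeignKey markers are generated on each column.
element = Table(
    'element', metadata,
    Column('id', Integer, primary_key=True),
    Column('node_id', Integer),
    Column('version_id', Integer),
    ForeignKeyConstraint(
        ['node_id', 'version_id'],
        ['node.node_id', 'node.version_id'],
        name='fk_element_node',
        use_alter=True))
```
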
- - """ - super(ForeignKeyConstraint, self).__init__(name, deferrable, initially) - - self.onupdate = onupdate - self.ondelete = ondelete - self.link_to_name = link_to_name - if self.name is None and use_alter: - raise exc.ArgumentError("Alterable Constraint requires a name") - self.use_alter = use_alter - - self._elements = util.OrderedDict() - - # standalone ForeignKeyConstraint - create - # associated ForeignKey objects which will be applied to hosted - # Column objects (in col.foreign_keys), either now or when attached - # to the Table for string-specified names - for col, refcol in zip(columns, refcolumns): - self._elements[col] = ForeignKey( - refcol, - _constraint=self, - name=self.name, - onupdate=self.onupdate, - ondelete=self.ondelete, - use_alter=self.use_alter, - link_to_name=self.link_to_name - ) - - if table: - self._set_parent(table) - - @property - def columns(self): - return self._elements.keys() - - @property - def elements(self): - return self._elements.values() - - def _set_parent(self, table): - super(ForeignKeyConstraint, self)._set_parent(table) - for col, fk in self._elements.iteritems(): - # string-specified column names now get - # resolved to Column objects - if isinstance(col, basestring): - col = table.c[col] - fk._set_parent(col) - - if self.use_alter: - def supports_alter(ddl, event, schema_item, bind, **kw): - return table in set(kw['tables']) and bind.dialect.supports_alter - AddConstraint(self, on=supports_alter).execute_at('after-create', table.metadata) - DropConstraint(self, on=supports_alter).execute_at('before-drop', table.metadata) - - def copy(self, **kw): - return ForeignKeyConstraint( - [x.parent.name for x in self._elements.values()], - [x._get_colspec(**kw) for x in self._elements.values()], - name=self.name, - onupdate=self.onupdate, - ondelete=self.ondelete, - use_alter=self.use_alter, - deferrable=self.deferrable, - initially=self.initially, - link_to_name=self.link_to_name - ) - -class PrimaryKeyConstraint(ColumnCollectionConstraint): - """A table-level PRIMARY KEY constraint. - - Defines a single column or composite PRIMARY KEY constraint. For a - no-frills primary key, adding ``primary_key=True`` to one or more - ``Column`` definitions is a shorthand equivalent for an unnamed single- or - multiple-column PrimaryKeyConstraint. - """ - - __visit_name__ = 'primary_key_constraint' - - def _set_parent(self, table): - super(PrimaryKeyConstraint, self)._set_parent(table) - table._set_primary_key(self) - - def _replace(self, col): - self.columns.replace(col) - -class UniqueConstraint(ColumnCollectionConstraint): - """A table-level UNIQUE constraint. - - Defines a single column or composite UNIQUE constraint. For a no-frills, - single column constraint, adding ``unique=True`` to the ``Column`` - definition is a shorthand equivalent for an unnamed, single column - UniqueConstraint. - """ - - __visit_name__ = 'unique_constraint' - -class Index(SchemaItem): - """A table-level INDEX. - - Defines a composite (one or more column) INDEX. For a no-frills, single - column index, adding ``index=True`` to the ``Column`` definition is - a shorthand equivalent for an unnamed, single column Index. - """ - - __visit_name__ = 'index' - - def __init__(self, name, *columns, **kwargs): - """Construct an index object. - - Arguments are: - - name - The name of the index - - \*columns - Columns to include in the index. All columns must belong to the same - table. - - \**kwargs - Keyword arguments include: - - unique - Defaults to False: create a unique index. 
- - postgresql_where - Defaults to None: create a partial index when using PostgreSQL - """ - - self.name = name - self.columns = expression.ColumnCollection() - self.table = None - self.unique = kwargs.pop('unique', False) - self.kwargs = kwargs - - for column in columns: - column = _to_schema_column(column) - if self.table is None: - self._set_parent(column.table) - elif column.table != self.table: - # all columns must be from the same table - raise exc.ArgumentError( - "All index columns must be from the same table. " - "%s is from %s not %s" % (column, column.table, self.table)) - self.columns.add(column) - - def _set_parent(self, table): - self.table = table - table.indexes.add(self) - - @property - def bind(self): - """Return the connectable associated with this Index.""" - - return self.table.bind - - def create(self, bind=None): - if bind is None: - bind = _bind_or_error(self) - bind.create(self) - return self - - def drop(self, bind=None): - if bind is None: - bind = _bind_or_error(self) - bind.drop(self) - - def __repr__(self): - return 'Index("%s", %s%s)' % (self.name, - ', '.join(repr(c) for c in self.columns), - (self.unique and ', unique=True') or '') - -class MetaData(SchemaItem): - """A collection of Tables and their associated schema constructs. - - Holds a collection of Tables and an optional binding to an ``Engine`` or - ``Connection``. If bound, the :class:`~sqlalchemy.schema.Table` objects - in the collection and their columns may participate in implicit SQL - execution. - - The `Table` objects themselves are stored in the `metadata.tables` - dictionary. - - The ``bind`` property may be assigned to dynamically. A common pattern is - to start unbound and then bind later when an engine is available:: - - metadata = MetaData() - # define tables - Table('mytable', metadata, ...) - # connect to an engine later, perhaps after loading a URL from a - # configuration file - metadata.bind = an_engine - - MetaData is a thread-safe object after tables have been explicitly defined - or loaded via reflection. - - .. index:: - single: thread safety; MetaData - - """ - - __visit_name__ = 'metadata' - - ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop') - - def __init__(self, bind=None, reflect=False): - """Create a new MetaData object. - - bind - An Engine or Connection to bind to. May also be a string or URL - instance; these are passed to create_engine() and this MetaData will - be bound to the resulting engine. - - reflect - Optional, automatically load all tables from the bound database. - Defaults to False. ``bind`` is required when this option is set. - For finer control over loaded tables, use the ``reflect`` method of - ``MetaData``. - - """ - self.tables = {} - self.bind = bind - self.metadata = self - self.ddl_listeners = util.defaultdict(list) - if reflect: - if not bind: - raise exc.ArgumentError( - "A bind must be supplied in conjunction with reflect=True") - self.reflect() - - def __repr__(self): - return 'MetaData(%r)' % self.bind - - def __contains__(self, table_or_key): - if not isinstance(table_or_key, basestring): - table_or_key = table_or_key.key - return table_or_key in self.tables - - def __getstate__(self): - return {'tables': self.tables} - - def __setstate__(self, state): - self.tables = state['tables'] - self._bind = None - - def is_bound(self): - """True if this MetaData is bound to an Engine or Connection.""" - - return self._bind is not None - - def bind(self): - """An Engine or Connection to which this MetaData is bound.
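
The bind-later pattern from the class docstring above, shown end to end; a sketch using an in-memory SQLite URL, relying on the implicit-execution idiom that this schema module supports:

```python
from sqlalchemy import MetaData, Table, Column, Integer

metadata = MetaData()
things = Table('thing', metadata,
               Column('id', Integer, primary_key=True))

# A string URL is coerced through create_engine() by _bind_to().
metadata.bind = 'sqlite://'
assert metadata.is_bound()
metadata.create_all()    # uses the implicit bind
```
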
- - This property may be assigned an ``Engine`` or ``Connection``, or - assigned a string or URL to automatically create a basic ``Engine`` - for this bind with ``create_engine()``. - - """ - return self._bind - - def _bind_to(self, bind): - """Bind this MetaData to an Engine, Connection, string or URL.""" - - global URL - if URL is None: - from sqlalchemy.engine.url import URL - - if isinstance(bind, (basestring, URL)): - from sqlalchemy import create_engine - self._bind = create_engine(bind) - else: - self._bind = bind - bind = property(bind, _bind_to) - - def clear(self): - """Clear all Table objects from this MetaData.""" - # TODO: why have clear()/remove() but not all - # other accessors/mutators for the tables dict? - self.tables.clear() - - def remove(self, table): - """Remove the given Table object from this MetaData.""" - - # TODO: scan all other tables and remove FK _column - del self.tables[table.key] - - @property - def sorted_tables(self): - """Returns a list of ``Table`` objects sorted in order of - dependency. - """ - from sqlalchemy.sql.util import sort_tables - return sort_tables(self.tables.itervalues()) - - def reflect(self, bind=None, schema=None, only=None): - """Load all available table definitions from the database. - - Automatically creates ``Table`` entries in this ``MetaData`` for any - table available in the database but not yet present in the - ``MetaData``. May be called multiple times to pick up tables recently - added to the database; however, no special action is taken if a table - in this ``MetaData`` no longer exists in the database. - - bind - A :class:`~sqlalchemy.engine.base.Connectable` used to access the database; if None, uses the - existing bind on this ``MetaData``, if any. - - schema - Optional, query and reflect tables from an alternate schema. - - only - Optional. Load only a subset of available named tables. May be - specified as a sequence of names or a callable. - - If a sequence of names is provided, only those tables will be - reflected. An error is raised if a table is requested but not - available. Named tables already present in this ``MetaData`` are - ignored. - - If a callable is provided, it will be used as a boolean predicate to - filter the list of potential table names. The callable is called - with a table name and this ``MetaData`` instance as positional - arguments and should return a true value for any table to reflect. - - """ - reflect_opts = {'autoload': True} - if bind is None: - bind = _bind_or_error(self) - conn = None - else: - reflect_opts['autoload_with'] = bind - conn = bind.contextual_connect() - - if schema is not None: - reflect_opts['schema'] = schema - - available = util.OrderedSet(bind.engine.table_names(schema, - connection=conn)) - current = set(self.tables.iterkeys()) - - if only is None: - load = [name for name in available if name not in current] - elif util.callable(only): - load = [name for name in available - if name not in current and only(name, self)] - else: - missing = [name for name in only if name not in available] - if missing: - s = schema and (" schema '%s'" % schema) or '' - raise exc.InvalidRequestError( - 'Could not reflect: requested table(s) not available ' - 'in %s%s: (%s)' % (bind.engine.url, s, ', '.join(missing))) - load = [name for name in only if name not in current] - - for name in load: - Table(name, self, **reflect_opts) - - def append_ddl_listener(self, event, listener): - """Append a DDL event listener to this ``MetaData``.
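
A runnable sketch of ``reflect`` with both forms of ``only`` (an in-memory SQLite database stands in for a real one; names invented):

```python
from sqlalchemy import MetaData, Table, Column, Integer, create_engine

engine = create_engine('sqlite://')
source = MetaData()
Table('users', source, Column('id', Integer, primary_key=True))
Table('orders', source, Column('id', Integer, primary_key=True))
source.create_all(engine)

# Load a named subset...
subset = MetaData()
subset.reflect(bind=engine, only=['users'])
assert 'users' in subset.tables and 'orders' not in subset.tables

# ...or filter with a predicate over candidate table names.
filtered = MetaData()
filtered.reflect(bind=engine, only=lambda name, meta: name.endswith('s'))
```
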
- - The ``listener`` callable will be triggered when this ``MetaData`` is - involved in DDL creates or drops, and will be invoked either before - all Table-related actions or after. - - Arguments are: - - event - One of ``MetaData.ddl_events``; 'before-create', 'after-create', - 'before-drop' or 'after-drop'. - listener - A callable, invoked with three positional arguments: - - event - The event currently being handled - target - The ``MetaData`` object being operated upon - bind - The ``Connection`` being used for DDL execution. - - Listeners are added to the MetaData's ``ddl_listeners`` attribute. - - Note: MetaData listeners are invoked even when ``Tables`` are created - in isolation. This may change in a future release. I.e.:: - - # triggers all MetaData and Table listeners: - metadata.create_all() - - # triggers MetaData listeners too: - some.table.create() - - """ - if event not in self.ddl_events: - raise LookupError(event) - self.ddl_listeners[event].append(listener) - - def create_all(self, bind=None, tables=None, checkfirst=True): - """Create all tables stored in this metadata. - - Conditional by default, will not attempt to recreate tables already - present in the target database. - - bind - A :class:`~sqlalchemy.engine.base.Connectable` used to access the database; if None, uses the - existing bind on this ``MetaData``, if any. - - tables - Optional list of ``Table`` objects, which is a subset of the total - tables in the ``MetaData`` (others are ignored). - - checkfirst - Defaults to True, don't issue CREATEs for tables already present - in the target database. - - """ - if bind is None: - bind = _bind_or_error(self) - bind.create(self, checkfirst=checkfirst, tables=tables) - - def drop_all(self, bind=None, tables=None, checkfirst=True): - """Drop all tables stored in this metadata. - - Conditional by default, will not attempt to drop tables not present in - the target database. - - bind - A :class:`~sqlalchemy.engine.base.Connectable` used to access the database; if None, uses - the existing bind on this ``MetaData``, if any. - - tables - Optional list of ``Table`` objects, which is a subset of the - total tables in the ``MetaData`` (others are ignored). - - checkfirst - Defaults to True, only issue DROPs for tables confirmed to be present - in the target database. - - """ - if bind is None: - bind = _bind_or_error(self) - bind.drop(self, checkfirst=checkfirst, tables=tables) - -class ThreadLocalMetaData(MetaData): - """A MetaData variant that presents a different ``bind`` in every thread. - - Makes the ``bind`` property of the MetaData a thread-local value, allowing - this collection of tables to be bound to different ``Engine`` - implementations or connections in each thread. - - The ThreadLocalMetaData starts off bound to None in each thread. Binds - must be made explicitly by assigning to the ``bind`` property or using - ``connect()``. You can also re-bind dynamically multiple times per - thread, just like a regular ``MetaData``. - - """ - - __visit_name__ = 'metadata' - - def __init__(self): - """Construct a ThreadLocalMetaData.""" - - self.context = util.threading.local() - self.__engines = {} - super(ThreadLocalMetaData, self).__init__() - - def bind(self): - """The bound Engine or Connection for this thread.
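
The listener signature documented above, as a minimal sketch (the print body is illustrative only; the table name is invented):

```python
from sqlalchemy import MetaData, Table, Column, Integer, create_engine

metadata = MetaData()
logged = Table('logged', metadata,
               Column('id', Integer, primary_key=True))

def log_ddl(event, target, bind):
    # event: e.g. 'after-create'; target: this MetaData;
    # bind: the Connection used for the DDL
    print('%s on %r via %r' % (event, target, bind))

metadata.append_ddl_listener('after-create', log_ddl)
metadata.create_all(create_engine('sqlite://'))  # fires the listener
```
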
- - This property may be assigned an Engine or Connection, or assigned a - string or URL to automatically create a basic Engine for this bind - with ``create_engine()``.""" - - return getattr(self.context, '_engine', None) - - def _bind_to(self, bind): - """Bind to a Connectable in the caller's thread.""" - - global URL - if URL is None: - from sqlalchemy.engine.url import URL - - if isinstance(bind, (basestring, URL)): - try: - self.context._engine = self.__engines[bind] - except KeyError: - from sqlalchemy import create_engine - e = create_engine(bind) - self.__engines[bind] = e - self.context._engine = e - else: - # TODO: this is squirrelly. we shouldn't have to hold onto engines - # in a case like this - if bind not in self.__engines: - self.__engines[bind] = bind - self.context._engine = bind - - bind = property(bind, _bind_to) - - def is_bound(self): - """True if there is a bind for this thread.""" - return (hasattr(self.context, '_engine') and - self.context._engine is not None) - - def dispose(self): - """Dispose all bound engines, in all thread contexts.""" - - for e in self.__engines.itervalues(): - if hasattr(e, 'dispose'): - e.dispose() - -class SchemaVisitor(visitors.ClauseVisitor): - """Define the visiting for ``SchemaItem`` objects.""" - - __traverse_options__ = {'schema_visitor':True} - - -class DDLElement(expression.Executable, expression.ClauseElement): - """Base class for DDL expression constructs.""" - - _execution_options = expression.Executable.\ - _execution_options.union({'autocommit':True}) - - target = None - on = None - - def execute(self, bind=None, target=None): - """Execute this DDL immediately. - - Executes the DDL statement in isolation using the supplied - :class:`~sqlalchemy.engine.base.Connectable`, or the :class:`~sqlalchemy.engine.base.Connectable` assigned to the ``.bind`` property - if none is supplied. If the DDL has a conditional ``on`` criteria, it - will be invoked with None as the event. - - bind - Optional, an ``Engine`` or ``Connection``. If not supplied, a - valid :class:`~sqlalchemy.engine.base.Connectable` must be present in the ``.bind`` property. - - target - Optional, defaults to None. The target SchemaItem for the - execute call. Will be passed to the ``on`` callable if any, - and may also provide string expansion data for the - statement. See ``execute_at`` for more information. - """ - - if bind is None: - bind = _bind_or_error(self) - - if self._should_execute(None, target, bind): - return bind.execute(self.against(target)) - else: - bind.engine.logger.info("DDL execution skipped, criteria not met.") - - def execute_at(self, event, target): - """Link execution of this DDL to the DDL lifecycle of a SchemaItem. - - Links this ``DDLElement`` to a ``Table`` or ``MetaData`` instance, executing - it when that schema item is created or dropped. The DDL statement - will be executed using the same Connection and transactional context - as the Table create/drop itself. The ``.bind`` property of this - statement is ignored. - - event - One of the events defined in the schema item's ``.ddl_events``; - e.g. 'before-create', 'after-create', 'before-drop' or 'after-drop' - - target - The Table or MetaData instance with which this DDLElement will - be associated. - - A DDLElement instance can be linked to any number of schema items. - - ``execute_at`` builds on the ``append_ddl_listener`` interface of - MetaData and Table objects. - - Caveat: Creating or dropping a Table in isolation will also trigger - any DDL set to ``execute_at`` that Table's MetaData.
This may change - in a future release. - """ - - if not hasattr(target, 'ddl_listeners'): - raise exc.ArgumentError( - "%s does not support DDL events" % type(target).__name__) - if event not in target.ddl_events: - raise exc.ArgumentError( - "Unknown event, expected one of (%s), got '%r'" % - (', '.join(target.ddl_events), event)) - target.ddl_listeners[event].append(self) - return self - - @expression._generative - def against(self, target): - """Return a copy of this DDL against a specific schema item.""" - - self.target = target - - def __call__(self, event, target, bind, **kw): - """Execute the DDL as a ddl_listener.""" - - if self._should_execute(event, target, bind, **kw): - return bind.execute(self.against(target)) - - def _check_ddl_on(self, on): - if (on is not None and - (not isinstance(on, (basestring, tuple, list, set)) and not util.callable(on))): - raise exc.ArgumentError( - "Expected the name of a database dialect, a tuple of names, or a callable for " - "'on' criteria, got type '%s'." % type(on).__name__) - - def _should_execute(self, event, target, bind, **kw): - if self.on is None: - return True - elif isinstance(self.on, basestring): - return self.on == bind.engine.name - elif isinstance(self.on, (tuple, list, set)): - return bind.engine.name in self.on - else: - return self.on(self, event, target, bind, **kw) - - def bind(self): - if self._bind: - return self._bind - def _set_bind(self, bind): - self._bind = bind - bind = property(bind, _set_bind) - - def _generate(self): - s = self.__class__.__new__(self.__class__) - s.__dict__ = self.__dict__.copy() - return s - - def _compiler(self, dialect, **kw): - """Return a compiler appropriate for this ClauseElement, given a Dialect.""" - - return dialect.ddl_compiler(dialect, self, **kw) - -class DDL(DDLElement): - """A literal DDL statement. - - Specifies literal SQL DDL to be executed by the database. DDL objects can - be attached to ``Tables`` or ``MetaData`` instances, conditionally - executing SQL as part of the DDL lifecycle of those schema items. Basic - templating support allows a single DDL instance to handle repetitive tasks - for multiple tables. - - Examples:: - - tbl = Table('users', metadata, Column('uid', Integer)) # ... - DDL('DROP TRIGGER users_trigger').execute_at('before-create', tbl) - - spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE', on='somedb') - spow.execute_at('after-create', tbl) - - drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE') - connection.execute(drop_spow) - - When operating on Table events, the following ``statement`` - string substitutions are available:: - - %(table)s - the Table name, with any required quoting applied - %(schema)s - the schema name, with any required quoting applied - %(fullname)s - the Table name including schema, quoted if needed - - The DDL's ``context``, if any, will be combined with the standard - substitutions noted above. Keys present in the context will override - the standard substitutions. - - """ - - __visit_name__ = "ddl" - - def __init__(self, statement, on=None, context=None, bind=None): - """Create a DDL statement. - - statement - A string or unicode string to be executed. Statements will be - processed with Python's string formatting operator. See the - ``context`` argument and the ``execute_at`` method. - - A literal '%' in a statement must be escaped as '%%'. - - SQL bind parameters are not available in DDL statements. - - on - Optional filtering criteria. May be a string, tuple or a callable - predicate.
If a string, it will be compared to the name of the - executing database dialect:: - - DDL('something', on='postgresql') - - If a tuple, specifies multiple dialect names:: - - DDL('something', on=('postgresql', 'mysql')) - - If a callable, it will be invoked with four positional arguments - as well as optional keyword arguments: - - ddl - This DDL element. - - event - The name of the event that has triggered this DDL, such as - 'after-create'. Will be None if the DDL is executed explicitly. - - target - The ``Table`` or ``MetaData`` object which is the target of - this event. May be None if the DDL is executed explicitly. - - connection - The ``Connection`` being used for DDL execution. - - \**kw - Keyword arguments which may be sent include: - tables - a list of Table objects which are to be created/ - dropped within a MetaData.create_all() or drop_all() method - call. - - If the callable returns a true value, the DDL statement will be - executed. - - context - Optional dictionary, defaults to None. These values will be - available for use in string substitutions on the DDL statement. - - bind - Optional. A :class:`~sqlalchemy.engine.base.Connectable`, used by default when ``execute()`` - is invoked without a bind argument. - - """ - - if not isinstance(statement, basestring): - raise exc.ArgumentError( - "Expected a string or unicode SQL statement, got '%r'" % - statement) - - self.statement = statement - self.context = context or {} - - self._check_ddl_on(on) - self.on = on - self._bind = bind - - - def __repr__(self): - return '<%s@%s; %s>' % ( - type(self).__name__, id(self), - ', '.join([repr(self.statement)] + - ['%s=%r' % (key, getattr(self, key)) - for key in ('on', 'context') - if getattr(self, key)])) - -def _to_schema_column(element): - if hasattr(element, '__clause_element__'): - element = element.__clause_element__() - if not isinstance(element, Column): - raise exc.ArgumentError("schema.Column object expected") - return element - -def _to_schema_column_or_string(element): - if hasattr(element, '__clause_element__'): - element = element.__clause_element__() - return element - -class _CreateDropBase(DDLElement): - """Base class for DDL constructs that represent CREATE and DROP or equivalents. - - The common theme of _CreateDropBase is a single - ``element`` attribute which refers to the element - to be created or dropped. - - """ - - def __init__(self, element, on=None, bind=None): - self.element = element - self._check_ddl_on(on) - self.on = on - self.bind = bind - - def _create_rule_disable(self, compiler): - """Allow disabling of _create_rule using a callable. - - Pass to _create_rule using - util.portable_instancemethod(self._create_rule_disable) - to retain serializability.
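
Tying the pieces together, ``on`` filtering plus the ``%(fullname)s`` substitution from the class docstring; the ALTER statement itself is an invented, PostgreSQL-only example:

```python
from sqlalchemy import MetaData, Table, Column, Integer, DDL

metadata = MetaData()
users = Table('users', metadata, Column('uid', Integer))

# %(fullname)s expands when the event fires; on='postgresql'
# skips execution on every other dialect.
fill = DDL("ALTER TABLE %(fullname)s SET (fillfactor = 90)",
           on='postgresql')
fill.execute_at('after-create', users)
```
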
- - """ - return False - -class CreateTable(_CreateDropBase): - """Represent a CREATE TABLE statement.""" - - __visit_name__ = "create_table" - -class DropTable(_CreateDropBase): - """Represent a DROP TABLE statement.""" - - __visit_name__ = "drop_table" - -class CreateSequence(_CreateDropBase): - """Represent a CREATE SEQUENCE statement.""" - - __visit_name__ = "create_sequence" - -class DropSequence(_CreateDropBase): - """Represent a DROP SEQUENCE statement.""" - - __visit_name__ = "drop_sequence" - -class CreateIndex(_CreateDropBase): - """Represent a CREATE INDEX statement.""" - - __visit_name__ = "create_index" - -class DropIndex(_CreateDropBase): - """Represent a DROP INDEX statement.""" - - __visit_name__ = "drop_index" - -class AddConstraint(_CreateDropBase): - """Represent an ALTER TABLE ADD CONSTRAINT statement.""" - - __visit_name__ = "add_constraint" - - def __init__(self, element, *args, **kw): - super(AddConstraint, self).__init__(element, *args, **kw) - element._create_rule = util.portable_instancemethod(self._create_rule_disable) - -class DropConstraint(_CreateDropBase): - """Represent an ALTER TABLE DROP CONSTRAINT statement.""" - - __visit_name__ = "drop_constraint" - - def __init__(self, element, cascade=False, **kw): - self.cascade = cascade - super(DropConstraint, self).__init__(element, **kw) - element._create_rule = util.portable_instancemethod(self._create_rule_disable) - -def _bind_or_error(schemaitem, msg=None): - bind = schemaitem.bind - if not bind: - name = schemaitem.__class__.__name__ - label = getattr(schemaitem, 'fullname', - getattr(schemaitem, 'name', None)) - if label: - item = '%s %r' % (name, label) - else: - item = name - if isinstance(schemaitem, (MetaData, DDL)): - bindable = "the %s's .bind" % name - else: - bindable = "this %s's .metadata.bind" % name - - if msg is None: - msg = ('The %s is not bound to an Engine or Connection. ' - 'Execution can not proceed without a database to execute ' - 'against. 
Either execute with an explicit connection or ' - 'assign %s to enable implicit execution.') % (item, bindable) - raise exc.UnboundExecutionError(msg) - return bind +from .sql.base import ( + SchemaVisitor + ) + + +from .sql.schema import ( + BLANK_SCHEMA, + CheckConstraint, + Column, + ColumnDefault, + Constraint, + DefaultClause, + DefaultGenerator, + FetchedValue, + ForeignKey, + ForeignKeyConstraint, + Index, + MetaData, + PassiveDefault, + PrimaryKeyConstraint, + SchemaItem, + Sequence, + Table, + ThreadLocalMetaData, + UniqueConstraint, + _get_table_key, + ColumnCollectionConstraint, + ColumnCollectionMixin + ) + + +from .sql.naming import conv + + +from .sql.ddl import ( + DDL, + CreateTable, + DropTable, + CreateSequence, + DropSequence, + CreateIndex, + DropIndex, + CreateSchema, + DropSchema, + _DropView, + CreateColumn, + AddConstraint, + DropConstraint, + DDLBase, + DDLElement, + _CreateDropBase, + _DDLCompiles, + sort_tables, + sort_tables_and_constraints +) diff --git a/sqlalchemy/sql/__init__.py b/sqlalchemy/sql/__init__.py index aa18eac..5eebd7d 100644 --- a/sqlalchemy/sql/__init__.py +++ b/sqlalchemy/sql/__init__.py @@ -1,4 +1,11 @@ -from sqlalchemy.sql.expression import ( +# sql/__init__.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .expression import ( Alias, ClauseElement, ColumnCollection, @@ -11,9 +18,12 @@ from sqlalchemy.sql.expression import ( Select, Selectable, TableClause, + TableSample, Update, alias, and_, + any_, + all_, asc, between, bindparam, @@ -28,12 +38,16 @@ from sqlalchemy.sql.expression import ( except_all, exists, extract, + false, + False_, func, + funcfilter, insert, intersect, intersect_all, join, label, + lateral, literal, literal_column, modifier, @@ -42,17 +56,43 @@ from sqlalchemy.sql.expression import ( or_, outerjoin, outparam, + over, select, subquery, table, + tablesample, text, + true, + True_, tuple_, + type_coerce, union, union_all, update, - ) + within_group +) -from sqlalchemy.sql.visitors import ClauseVisitor +from .visitors import ClauseVisitor -__tmp = locals().keys() -__all__ = sorted([i for i in __tmp if not i.startswith('__')]) + +def __go(lcls): + global __all__ + from .. import util as _sa_util + + import inspect as _inspect + + __all__ = sorted(name for name, obj in lcls.items() + if not (name.startswith('_') or _inspect.ismodule(obj))) + + from .annotation import _prepare_annotations, Annotated + from .elements import AnnotatedColumnElement, ClauseList + from .selectable import AnnotatedFromClause + _prepare_annotations(ColumnElement, AnnotatedColumnElement) + _prepare_annotations(FromClause, AnnotatedFromClause) + _prepare_annotations(ClauseList, Annotated) + + _sa_util.dependencies.resolve_all("sqlalchemy.sql") + + from . 
import naming + +__go(locals()) diff --git a/sqlalchemy/sql/compiler.py b/sqlalchemy/sql/compiler.py index 78c6577..bfa22c2 100644 --- a/sqlalchemy/sql/compiler.py +++ b/sqlalchemy/sql/compiler.py @@ -1,5 +1,6 @@ -# compiler.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# sql/compiler.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -8,25 +9,26 @@ Classes provided include: -:class:`~sqlalchemy.sql.compiler.SQLCompiler` - renders SQL +:class:`.compiler.SQLCompiler` - renders SQL strings -:class:`~sqlalchemy.sql.compiler.DDLCompiler` - renders DDL +:class:`.compiler.DDLCompiler` - renders DDL (data definition language) strings -:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders +:class:`.compiler.GenericTypeCompiler` - renders type specification strings. -To generate user-defined SQL strings, see -:module:`~sqlalchemy.ext.compiler`. +To generate user-defined SQL strings, see +:doc:`/ext/compiler`. """ +import contextlib import re -from sqlalchemy import schema, engine, util, exc -from sqlalchemy.sql import operators, functions, util as sql_util, visitors -from sqlalchemy.sql import expression as sql -import decimal +from . import schema, sqltypes, operators, functions, visitors, \ + elements, selectable, crud +from .. import util, exc +import itertools RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -48,63 +50,68 @@ RESERVED_WORDS = set([ 'using', 'verbose', 'when', 'where']) LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I) -ILLEGAL_INITIAL_CHARACTERS = set([str(x) for x in xrange(0, 10)]).union(['$']) +ILLEGAL_INITIAL_CHARACTERS = set([str(x) for x in range(0, 10)]).union(['$']) BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE) BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE) BIND_TEMPLATES = { 'pyformat': "%%(%(name)s)s", 'qmark': "?", 'format': "%%s", 'numeric': ":[_POSITION]", 'named': ":%(name)s" } OPERATORS = { # binary - operators.and_ : ' AND ', - operators.or_ : ' OR ', - operators.add : ' + ', - operators.mul : ' * ', - operators.sub : ' - ', - operators.div : ' / ', - operators.mod : ' % ', - operators.truediv : ' / ', - operators.neg : '-', - operators.lt : ' < ', - operators.le : ' <= ', - operators.ne : ' != ', - operators.gt : ' > 
', - operators.ge : ' >= ', - operators.eq : ' = ', - operators.concat_op : ' || ', - operators.between_op : ' BETWEEN ', - operators.match_op : ' MATCH ', - operators.in_op : ' IN ', - operators.notin_op : ' NOT IN ', - operators.comma_op : ', ', - operators.from_ : ' FROM ', - operators.as_ : ' AS ', - operators.is_ : ' IS ', - operators.isnot : ' IS NOT ', - operators.collate : ' COLLATE ', + operators.and_: ' AND ', + operators.or_: ' OR ', + operators.add: ' + ', + operators.mul: ' * ', + operators.sub: ' - ', + operators.div: ' / ', + operators.mod: ' % ', + operators.truediv: ' / ', + operators.neg: '-', + operators.lt: ' < ', + operators.le: ' <= ', + operators.ne: ' != ', + operators.gt: ' > ', + operators.ge: ' >= ', + operators.eq: ' = ', + operators.is_distinct_from: ' IS DISTINCT FROM ', + operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ', + operators.concat_op: ' || ', + operators.match_op: ' MATCH ', + operators.notmatch_op: ' NOT MATCH ', + operators.in_op: ' IN ', + operators.notin_op: ' NOT IN ', + operators.comma_op: ', ', + operators.from_: ' FROM ', + operators.as_: ' AS ', + operators.is_: ' IS ', + operators.isnot: ' IS NOT ', + operators.collate: ' COLLATE ', # unary - operators.exists : 'EXISTS ', - operators.distinct_op : 'DISTINCT ', - operators.inv : 'NOT ', + operators.exists: 'EXISTS ', + operators.distinct_op: 'DISTINCT ', + operators.inv: 'NOT ', + operators.any_op: 'ANY ', + operators.all_op: 'ALL ', # modifiers - operators.desc_op : ' DESC', - operators.asc_op : ' ASC', + operators.desc_op: ' DESC', + operators.asc_op: ' ASC', + operators.nullsfirst_op: ' NULLS FIRST', + operators.nullslast_op: ' NULLS LAST', + } FUNCTIONS = { - functions.coalesce : 'coalesce%(expr)s', + functions.coalesce: 'coalesce%(expr)s', functions.current_date: 'CURRENT_DATE', functions.current_time: 'CURRENT_TIME', functions.current_timestamp: 'CURRENT_TIMESTAMP', @@ -113,7 +120,7 @@ FUNCTIONS = { functions.localtimestamp: 'LOCALTIMESTAMP', functions.random: 'random%(expr)s', functions.sysdate: 'sysdate', - functions.session_user :'SESSION_USER', + functions.session_user: 'SESSION_USER', functions.user: 'USER' } @@ -136,116 +143,359 @@ EXTRACT_MAP = { } COMPOUND_KEYWORDS = { - sql.CompoundSelect.UNION : 'UNION', - sql.CompoundSelect.UNION_ALL : 'UNION ALL', - sql.CompoundSelect.EXCEPT : 'EXCEPT', - sql.CompoundSelect.EXCEPT_ALL : 'EXCEPT ALL', - sql.CompoundSelect.INTERSECT : 'INTERSECT', - sql.CompoundSelect.INTERSECT_ALL : 'INTERSECT ALL' + selectable.CompoundSelect.UNION: 'UNION', + selectable.CompoundSelect.UNION_ALL: 'UNION ALL', + selectable.CompoundSelect.EXCEPT: 'EXCEPT', + selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', + selectable.CompoundSelect.INTERSECT: 'INTERSECT', + selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' } + +class Compiled(object): + + """Represent a compiled SQL or DDL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + _cached_metadata = None + + execution_options = util.immutabledict() + """ + Execution options propagated from the statement. In some cases, + sub-elements of the statement can modify these. 
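
What the new ``Compiled`` base class means in practice: compiling a statement yields an object whose string form is the SQL, with bind values kept symbolic. A sketch (invented table; default dialect assumed):

```python
from sqlalchemy import MetaData, Table, Column, Integer, select

metadata = MetaData()
t = Table('t', metadata, Column('x', Integer))

compiled = select([t.c.x]).where(t.c.x == 5).compile()
print(compiled)         # SELECT t.x FROM t WHERE t.x = :x_1
print(compiled.params)  # {'x_1': 5}
```
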
+ """ + + def __init__(self, dialect, statement, bind=None, + schema_translate_map=None, + compile_kwargs=util.immutabledict()): + """Construct a new :class:`.Compiled` object. + + :param dialect: :class:`.Dialect` to compile against. + + :param statement: :class:`.ClauseElement` to be compiled. + + :param bind: Optional Engine or Connection to compile this + statement against. + + :param schema_translate_map: dictionary of schema names to be + translated when forming the resultant SQL + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`schema_translating` + + :param compile_kwargs: additional kwargs that will be + passed to the initial call to :meth:`.Compiled.process`. + + + """ + + self.dialect = dialect + self.bind = bind + self.preparer = self.dialect.identifier_preparer + if schema_translate_map: + self.preparer = self.preparer._with_schema_translate( + schema_translate_map) + + if statement is not None: + self.statement = statement + self.can_execute = statement.supports_execution + if self.can_execute: + self.execution_options = statement._execution_options + self.string = self.process(self.statement, **compile_kwargs) + + @util.deprecated("0.7", ":class:`.Compiled` objects now compile " + "within the constructor.") + def compile(self): + """Produce the internal string representation of this element. + """ + pass + + def _execute_on_connection(self, connection, multiparams, params): + if self.can_execute: + return connection._execute_compiled(self, multiparams, params) + else: + raise exc.ObjectNotExecutableError(self.statement) + + @property + def sql_compiler(self): + """Return a Compiled that is capable of processing SQL expressions. + + If this compiler is one, it would likely just return 'self'. + + """ + + raise NotImplementedError() + + def process(self, obj, **kwargs): + return obj._compiler_dispatch(self, **kwargs) + + def __str__(self): + """Return the string text of the generated SQL or DDL.""" + + return self.string or '' + + def construct_params(self, params=None): + """Return the bind params for this compiled object. + + :param params: a dict of string/object pairs whose values will + override bind values compiled in to the + statement. 
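
``construct_params`` is where the override happens: compiled-in values act as defaults that per-execution parameters can replace. Same invented table as above:

```python
from sqlalchemy import MetaData, Table, Column, Integer, select

metadata = MetaData()
t = Table('t', metadata, Column('x', Integer))
compiled = select([t.c.x]).where(t.c.x == 5).compile()

print(compiled.construct_params())                   # {'x_1': 5}
print(compiled.construct_params(params={'x_1': 9}))  # {'x_1': 9}
```
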
+ """ + + raise NotImplementedError() + + @property + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() + + def execute(self, *multiparams, **params): + """Execute this compiled object.""" + + e = self.bind + if e is None: + raise exc.UnboundExecutionError( + "This Compiled object is not bound to any Engine " + "or Connection.") + return e._execute_compiled(self, multiparams, params) + + def scalar(self, *multiparams, **params): + """Execute this compiled object and return the result's + scalar value.""" + + return self.execute(*multiparams, **params).scalar() + + +class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): + """Produces DDL specification for TypeEngine objects.""" + + ensure_kwarg = r'visit_\w+' + + def __init__(self, dialect): + self.dialect = dialect + + def process(self, type_, **kw): + return type_._compiler_dispatch(self, **kw) + + class _CompileLabel(visitors.Visitable): - """lightweight label object which acts as an expression._Label.""" + + """lightweight label object which acts as an expression.Label.""" __visit_name__ = 'label' __slots__ = 'element', 'name' - - def __init__(self, col, name): + + def __init__(self, col, name, alt_names=()): self.element = col self.name = name - + self._alt_names = (col,) + alt_names + @property - def quote(self): - return self.element.quote + def proxy_set(self): + return self.element.proxy_set -class SQLCompiler(engine.Compiled): - """Default implementation of Compiled. + @property + def type(self): + return self.element.type - Compiles ClauseElements into SQL strings. Uses a similar visit - paradigm as visitors.ClauseVisitor but implements its own traversal. + def self_group(self, **kw): + return self + + +class SQLCompiler(Compiled): + """Default implementation of :class:`.Compiled`. + + Compiles :class:`.ClauseElement` objects into SQL strings. """ extract_map = EXTRACT_MAP compound_keywords = COMPOUND_KEYWORDS - - # class-level defaults which can be set at the instance - # level to define if this Compiled instance represents - # INSERT/UPDATE/DELETE + isdelete = isinsert = isupdate = False - - # holds the "returning" collection of columns if - # the statement is CRUD and defines returning columns - # either implicitly or explicitly + """class-level defaults which can be set at the instance + level to define if this Compiled instance represents + INSERT/UPDATE/DELETE + """ + + isplaintext = False + returning = None - - # set to True classwide to generate RETURNING - # clauses before the VALUES or WHERE clause (i.e. MSSQL) + """holds the "returning" collection of columns if + the statement is CRUD and defines returning columns + either implicitly or explicitly + """ + returning_precedes_values = False - - # SQL 92 doesn't allow bind parameters to be used - # in the columns clause of a SELECT, nor does it allow - # ambiguous expressions like "? = ?". A compiler - # subclass can set this flag to False if the target - # driver/DB enforces this + """set to True classwide to generate RETURNING + clauses before the VALUES or WHERE clause (i.e. MSSQL) + """ + + render_table_with_column_in_update_from = False + """set to True classwide to indicate the SET clause + in a multi-table UPDATE statement should qualify + columns with the table name (i.e. MySQL only) + """ + ansi_bind_rules = False - - def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): - """Construct a new ``DefaultCompiler`` object. 
+ """SQL 92 doesn't allow bind parameters to be used + in the columns clause of a SELECT, nor does it allow + ambiguous expressions like "? = ?". A compiler + subclass can set this flag to False if the target + driver/DB enforces this + """ - dialect - Dialect to be used + _textual_ordered_columns = False + """tell the result object that the column names as rendered are important, + but they are also "ordered" vs. what is in the compiled object here. + """ - statement - ClauseElement to be compiled + _ordered_columns = True + """ + if False, means we can't be sure the list of entries + in _result_columns is actually the rendered order. Usually + True unless using an unordered TextAsFrom. + """ - column_keys - a list of column names to be compiled into an INSERT or UPDATE - statement. + insert_prefetch = update_prefetch = () + + + def __init__(self, dialect, statement, column_keys=None, + inline=False, **kwargs): + """Construct a new :class:`.SQLCompiler` object. + + :param dialect: :class:`.Dialect` to be used + + :param statement: :class:`.ClauseElement` to be compiled + + :param column_keys: a list of column names to be compiled into an + INSERT or UPDATE statement. + + :param inline: whether to generate INSERT statements as "inline", e.g. + not formatted to return any generated defaults + + :param kwargs: additional keyword arguments to be consumed by the + superclass. """ - engine.Compiled.__init__(self, dialect, statement, **kwargs) - self.column_keys = column_keys - # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) + # compile INSERT/UPDATE defaults/sequences inlined (no pre- + # execute) self.inline = inline or getattr(statement, 'inline', False) - # a dictionary of bind parameter keys to _BindParamClause instances. + # a dictionary of bind parameter keys to BindParameter + # instances. self.binds = {} - # a dictionary of _BindParamClause instances to "compiled" names that are - # actually present in the generated SQL + # a dictionary of BindParameter instances to "compiled" names + # that are actually present in the generated SQL self.bind_names = util.column_dict() # stack which keeps track of nested SELECT statements self.stack = [] - # relates label names in the final SQL to - # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine. - # ResultProxy uses this for type processing and column targeting - self.result_map = {} + # relates label names in the final SQL to a tuple of local + # column/label name, ColumnElement object (if any) and + # TypeEngine. 
ResultProxy uses this for type processing and + # column targeting + self._result_columns = [] # true if the paramstyle is positional - self.positional = self.dialect.positional + self.positional = dialect.positional if self.positional: self.positiontup = [] + self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] - self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle] + self.ctes = None - # an IdentifierPreparer that formats the quoting of identifiers - self.preparer = self.dialect.identifier_preparer + self.label_length = dialect.label_length \ + or dialect.max_identifier_length - self.label_length = self.dialect.label_length or self.dialect.max_identifier_length - - # a map which tracks "anonymous" identifiers that are - # created on the fly here + # a map which tracks "anonymous" identifiers that are created on + # the fly here self.anon_map = util.PopulateDict(self._process_anon) - # a map which tracks "truncated" names based on dialect.label_length - # or dialect.max_identifier_length + # a map which tracks "truncated" names based on + # dialect.label_length or dialect.max_identifier_length self.truncated_names = {} + Compiled.__init__(self, dialect, statement, **kwargs) + + if ( + self.isinsert or self.isupdate or self.isdelete + ) and statement._returning: + self.returning = statement._returning + + if self.positional and dialect.paramstyle == 'numeric': + self._apply_numbered_params() + + @property + def prefetch(self): + return list(self.insert_prefetch + self.update_prefetch) + + @util.memoized_instancemethod + def _init_cte_state(self): + """Initialize collections related to CTEs only if + a CTE is located, to save on the overhead of + these collections otherwise. + + """ + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_by_name = {} + self.ctes_recursive = False + if self.positional: + self.cte_positional = {} + + @contextlib.contextmanager + def _nested_result(self): + """special API to support the use case of 'nested result sets'""" + result_columns, ordered_columns = ( + self._result_columns, self._ordered_columns) + self._result_columns, self._ordered_columns = [], False + + try: + if self.stack: + entry = self.stack[-1] + entry['need_result_map_for_nested'] = True + else: + entry = None + yield self._result_columns, self._ordered_columns + finally: + if entry: + entry.pop('need_result_map_for_nested') + self._result_columns, self._ordered_columns = ( + result_columns, ordered_columns) + + def _apply_numbered_params(self): + poscount = itertools.count(1) + self.string = re.sub( + r'\[_POSITION\]', + lambda m: str(util.next(poscount)), + self.string) + + @util.memoized_property + def _bind_processors(self): + return dict( + (key, value) for key, value in + ((self.bind_names[bindparam], + bindparam.type._cached_bind_processor(self.dialect)) + for bindparam in self.bind_names) + if value is not None + ) def is_subquery(self): return len(self.stack) > 1 @@ -253,49 +503,69 @@ class SQLCompiler(engine.Compiled): @property def sql_compiler(self): return self - - def construct_params(self, params=None, _group_number=None): + + def construct_params(self, params=None, _group_number=None, _check=True): """return a dictionary of bind parameter keys and values""" if params: pd = {} - for bindparam, name in self.bind_names.iteritems(): - for paramname in (bindparam.key, name): - if paramname in params: - pd[name] = params[paramname] - break - else: - if bindparam.required: - if _group_number: - raise exc.InvalidRequestError( - "A value is 
required for bind parameter %r, " - "in parameter group %d" % - (bindparam.key, _group_number)) - else: - raise exc.InvalidRequestError( - "A value is required for bind parameter %r" - % bindparam.key) - elif util.callable(bindparam.value): - pd[name] = bindparam.value() + for bindparam in self.bind_names: + name = self.bind_names[bindparam] + if bindparam.key in params: + pd[name] = params[bindparam.key] + elif name in params: + pd[name] = params[name] + + elif _check and bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" % + (bindparam.key, _group_number)) else: - pd[name] = bindparam.value + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key) + + elif bindparam.callable: + pd[name] = bindparam.effective_value + else: + pd[name] = bindparam.value return pd else: pd = {} for bindparam in self.bind_names: - if util.callable(bindparam.value): - pd[self.bind_names[bindparam]] = bindparam.value() + if _check and bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" % + (bindparam.key, _group_number)) + else: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key) + + if bindparam.callable: + pd[self.bind_names[bindparam]] = bindparam.effective_value else: pd[self.bind_names[bindparam]] = bindparam.value return pd - params = property(construct_params, doc=""" - Return the bind params for this compiled object. + @property + def params(self): + """Return the bind param dictionary embedded into this + compiled object, for those values that are present.""" + return self.construct_params(_check=False) - """) + @util.dependencies("sqlalchemy.engine.result") + def _create_result_map(self, result): + """utility method used for unit tests only.""" + return result.ResultMetaData._create_result_map(self._result_columns) def default_from(self): - """Called when a SELECT statement has no froms, and no FROM clause is to be appended. + """Called when a SELECT statement has no froms, and no FROM clause is + to be appended. Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. @@ -303,61 +573,147 @@ class SQLCompiler(engine.Compiled): return "" def visit_grouping(self, grouping, asfrom=False, **kwargs): - return "(" + self.process(grouping.element, **kwargs) + ")" + return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" - def visit_label(self, label, result_map=None, - within_label_clause=False, - within_columns_clause=False, **kw): + def visit_label_reference( + self, element, within_columns_clause=False, **kwargs): + if self.stack and self.dialect.supports_simple_order_by_label: + selectable = self.stack[-1]['selectable'] + + with_cols, only_froms, only_cols = selectable._label_resolve_dict + if within_columns_clause: + resolve_dict = only_froms + else: + resolve_dict = only_cols + + # this can be None in the case that a _label_reference() + # were subject to a replacement operation, in which case + # the replacement of the Label element may have changed + # to something else like a ColumnClause expression. 
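
For orientation, the label-resolution branch above is what lets a dialect render an ORDER BY entry as the bare label name rather than repeating the labeled expression. A minimal Core-level sketch of the observable behavior (table and names invented for illustration, compiled against the default string dialect):

    from sqlalchemy import Column, Integer, MetaData, Table, func, select

    metadata = MetaData()
    t = Table('t', metadata, Column('x', Integer))
    cnt = func.count(t.c.x).label('cnt')

    # on dialects where supports_simple_order_by_label is True, the
    # ORDER BY renders just the label name, not the full aggregate
    print(select([cnt]).group_by(t.c.x).order_by(cnt))
    # SELECT count(t.x) AS cnt FROM t GROUP BY t.x ORDER BY cnt
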
+ order_by_elem = element.element._order_by_label_element + + if order_by_elem is not None and order_by_elem.name in \ + resolve_dict and \ + order_by_elem.shares_lineage( + resolve_dict[order_by_elem.name]): + kwargs['render_label_as_label'] = \ + element.element._order_by_label_element + return self.process( + element.element, within_columns_clause=within_columns_clause, + **kwargs) + + def visit_textual_label_reference( + self, element, within_columns_clause=False, **kwargs): + if not self.stack: + # compiling the element outside of the context of a SELECT + return self.process( + element._text_clause + ) + + selectable = self.stack[-1]['selectable'] + with_cols, only_froms, only_cols = selectable._label_resolve_dict + try: + if within_columns_clause: + col = only_froms[element.element] + else: + col = with_cols[element.element] + except KeyError: + # treat it like text() + util.warn_limited( + "Can't resolve label reference %r; converting to text()", + util.ellipses_string(element.element)) + return self.process( + element._text_clause + ) + else: + kwargs['render_label_as_label'] = col + return self.process( + col, within_columns_clause=within_columns_clause, **kwargs) + + def visit_label(self, label, + add_to_result_map=None, + within_label_clause=False, + within_columns_clause=False, + render_label_as_label=None, + **kw): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. - if within_columns_clause and not within_label_clause: - labelname = isinstance(label.name, sql._generated_label) and \ - self._truncated_identifier("colident", label.name) or label.name + render_label_with_as = (within_columns_clause and not + within_label_clause) + render_label_only = render_label_as_label is label - if result_map is not None: - result_map[labelname.lower()] = \ - (label.name, (label, label.element, labelname), label.element.type) + if render_label_only or render_label_with_as: + if isinstance(label.name, elements._truncated_label): + labelname = self._truncated_identifier("colident", label.name) + else: + labelname = label.name - return self.process(label.element, - within_columns_clause=True, - within_label_clause=True, - **kw) + \ - OPERATORS[operators.as_] + \ - self.preparer.format_label(label, labelname) + if render_label_with_as: + if add_to_result_map is not None: + add_to_result_map( + labelname, + label.name, + (label, labelname, ) + label._alt_names, + label.type + ) + + return label.element._compiler_dispatch( + self, within_columns_clause=True, + within_label_clause=True, **kw) + \ + OPERATORS[operators.as_] + \ + self.preparer.format_label(label, labelname) + elif render_label_only: + return self.preparer.format_label(label, labelname) else: - return self.process(label.element, - within_columns_clause=False, - **kw) - - def visit_column(self, column, result_map=None, **kwargs): - name = column.name - if not column.is_literal and isinstance(name, sql._generated_label): + return label.element._compiler_dispatch( + self, within_columns_clause=False, **kw) + + def _fallback_column_name(self, column): + raise exc.CompileError("Cannot compile Column object until " + "its 'name' is assigned.") + + def visit_column(self, column, add_to_result_map=None, + include_table=True, **kwargs): + name = orig_name = column.name + if name is None: + name = self._fallback_column_name(column) + + is_literal = column.is_literal + if not is_literal and isinstance(name, elements._truncated_label): name = 
self._truncated_identifier("colident", name) - if result_map is not None: - result_map[name.lower()] = (name, (column, ), column.type) - - if column.is_literal: + if add_to_result_map is not None: + add_to_result_map( + name, + orig_name, + (column, name, column.key), + column.type + ) + + if is_literal: name = self.escape_literal_column(name) else: - name = self.preparer.quote(name, column.quote) + name = self.preparer.quote(name) - if column.table is None or not column.table.named_with_column: + table = column.table + if table is None or not include_table or not table.named_with_column: return name else: - if column.table.schema: + effective_schema = self.preparer.schema_for_object(table) + + if effective_schema: schema_prefix = self.preparer.quote_schema( - column.table.schema, - column.table.quote_schema) + '.' + effective_schema) + '.' else: schema_prefix = '' - tablename = column.table.name - tablename = isinstance(tablename, sql._generated_label) and \ - self._truncated_identifier("alias", tablename) or tablename - + tablename = table.name + if isinstance(tablename, elements._truncated_label): + tablename = self._truncated_identifier("alias", tablename) + return schema_prefix + \ - self.preparer.quote(tablename, column.table.quote) + "." + name + self.preparer.quote(tablename) + \ + "." + name def escape_literal_column(self, text): """provide escaping for the literal_column() construct.""" @@ -371,94 +727,228 @@ class SQLCompiler(engine.Compiled): def visit_index(self, index, **kwargs): return index.name - def visit_typeclause(self, typeclause, **kwargs): - return self.dialect.type_compiler.process(typeclause.type) + def visit_typeclause(self, typeclause, **kw): + kw['type_expression'] = typeclause + return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): return text - - def visit_textclause(self, textclause, **kwargs): - if textclause.typemap is not None: - for colname, type_ in textclause.typemap.iteritems(): - self.result_map[colname.lower()] = (colname, None, type_) + def visit_textclause(self, textclause, **kw): def do_bindparam(m): name = m.group(1) - if name in textclause.bindparams: - return self.process(textclause.bindparams[name]) + if name in textclause._bindparams: + return self.process(textclause._bindparams[name], **kw) else: - return self.bindparam_string(name) + return self.bindparam_string(name, **kw) + + if not self.stack: + self.isplaintext = True # un-escape any \:params - return BIND_PARAMS_ESC.sub(lambda m: m.group(1), - BIND_PARAMS.sub(do_bindparam, self.post_process_text(textclause.text)) + return BIND_PARAMS_ESC.sub( + lambda m: m.group(1), + BIND_PARAMS.sub( + do_bindparam, + self.post_process_text(textclause.text)) ) - def visit_null(self, null, **kwargs): + def visit_text_as_from(self, taf, + compound_index=None, + asfrom=False, + parens=True, **kw): + + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + populate_result_map = toplevel or \ + ( + compound_index == 0 and entry.get( + 'need_result_map_for_compound', False) + ) or entry.get('need_result_map_for_nested', False) + + if populate_result_map: + self._ordered_columns = \ + self._textual_ordered_columns = taf.positional + for c in taf.column_args: + self.process(c, within_columns_clause=True, + add_to_result_map=self._add_to_result_map) + + text = self.process(taf.element, **kw) + if asfrom and parens: + text = "(%s)" % text + return text + + def visit_null(self, expr, **kw): return 'NULL' - def 
visit_clauselist(self, clauselist, **kwargs): + def visit_true(self, expr, **kw): + if self.dialect.supports_native_boolean: + return 'true' + else: + return "1" + + def visit_false(self, expr, **kw): + if self.dialect.supports_native_boolean: + return 'false' + else: + return "0" + + def visit_clauselist(self, clauselist, **kw): sep = clauselist.operator if sep is None: sep = " " else: sep = OPERATORS[clauselist.operator] - return sep.join(s for s in (self.process(c, **kwargs) for c in clauselist.clauses) - if s is not None) + return sep.join( + s for s in + ( + c._compiler_dispatch(self, **kw) + for c in clauselist.clauses) + if s) def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: - x += self.process(clause.value, **kwargs) + " " + x += clause.value._compiler_dispatch(self, **kwargs) + " " for cond, result in clause.whens: - x += "WHEN " + self.process(cond, **kwargs) + \ - " THEN " + self.process(result, **kwargs) + " " + x += "WHEN " + cond._compiler_dispatch( + self, **kwargs + ) + " THEN " + result._compiler_dispatch( + self, **kwargs) + " " if clause.else_ is not None: - x += "ELSE " + self.process(clause.else_, **kwargs) + " " + x += "ELSE " + clause.else_._compiler_dispatch( + self, **kwargs + ) + " " x += "END" return x + def visit_type_coerce(self, type_coerce, **kw): + return type_coerce.typed_expression._compiler_dispatch(self, **kw) + def visit_cast(self, cast, **kwargs): return "CAST(%s AS %s)" % \ - (self.process(cast.clause, **kwargs), self.process(cast.typeclause, **kwargs)) + (cast.clause._compiler_dispatch(self, **kwargs), + cast.typeclause._compiler_dispatch(self, **kwargs)) + + def _format_frame_clause(self, range_, **kw): + return '%s AND %s' % ( + "UNBOUNDED PRECEDING" + if range_[0] is elements.RANGE_UNBOUNDED + else "CURRENT ROW" if range_[0] is elements.RANGE_CURRENT + else "%s PRECEDING" % (self.process(range_[0], **kw), ), + + "UNBOUNDED FOLLOWING" + if range_[1] is elements.RANGE_UNBOUNDED + else "CURRENT ROW" if range_[1] is elements.RANGE_CURRENT + else "%s FOLLOWING" % (self.process(range_[1], **kw), ) + ) + + def visit_over(self, over, **kwargs): + if over.range_: + range_ = "RANGE BETWEEN %s" % self._format_frame_clause( + over.range_, **kwargs) + elif over.rows: + range_ = "ROWS BETWEEN %s" % self._format_frame_clause( + over.rows, **kwargs) + else: + range_ = None + + return "%s OVER (%s)" % ( + over.element._compiler_dispatch(self, **kwargs), + ' '.join([ + '%s BY %s' % ( + word, clause._compiler_dispatch(self, **kwargs) + ) + for word, clause in ( + ('PARTITION', over.partition_by), + ('ORDER', over.order_by) + ) + if clause is not None and len(clause) + ] + ([range_] if range_ else []) + ) + ) + + def visit_withingroup(self, withingroup, **kwargs): + return "%s WITHIN GROUP (ORDER BY %s)" % ( + withingroup.element._compiler_dispatch(self, **kwargs), + withingroup.order_by._compiler_dispatch(self, **kwargs) + ) + + def visit_funcfilter(self, funcfilter, **kwargs): + return "%s FILTER (WHERE %s)" % ( + funcfilter.func._compiler_dispatch(self, **kwargs), + funcfilter.criterion._compiler_dispatch(self, **kwargs) + ) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) - return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr, **kwargs)) + return "EXTRACT(%s FROM %s)" % ( + field, extract.expr._compiler_dispatch(self, **kwargs)) - def visit_function(self, func, result_map=None, **kwargs): - if result_map is not None: - result_map[func.name.lower()] = (func.name, 
None, func.type) + def visit_function(self, func, add_to_result_map=None, **kwargs): + if add_to_result_map is not None: + add_to_result_map( + func.name, func.name, (), func.type + ) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) if disp: return disp(func, **kwargs) else: name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") - return ".".join(func.packagenames + [name]) % \ - {'expr':self.function_argspec(func, **kwargs)} + return ".".join(list(func.packagenames) + [name]) % \ + {'expr': self.function_argspec(func, **kwargs)} + + def visit_next_value_func(self, next_value, **kw): + return self.visit_sequence(next_value.sequence) + + def visit_sequence(self, sequence): + raise NotImplementedError( + "Dialect '%s' does not support sequence increments." % + self.dialect.name + ) def function_argspec(self, func, **kwargs): - return self.process(func.clause_expr, **kwargs) + return func.clause_expr._compiler_dispatch(self, **kwargs) - def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs): - entry = self.stack and self.stack[-1] or {} - self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) + def visit_compound_select(self, cs, asfrom=False, + parens=True, compound_index=0, **kwargs): + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + need_result_map = toplevel or \ + (compound_index == 0 + and entry.get('need_result_map_for_compound', False)) + + self.stack.append( + { + 'correlate_froms': entry['correlate_froms'], + 'asfrom_froms': entry['asfrom_froms'], + 'selectable': cs, + 'need_result_map_for_compound': need_result_map + }) keyword = self.compound_keywords.get(cs.keyword) - + text = (" " + keyword + " ").join( - (self.process(c, asfrom=asfrom, parens=False, - compound_index=i, **kwargs) - for i, c in enumerate(cs.selects)) - ) - - group_by = self.process(cs._group_by_clause, asfrom=asfrom, **kwargs) + (c._compiler_dispatch(self, + asfrom=asfrom, parens=False, + compound_index=i, **kwargs) + for i, c in enumerate(cs.selects)) + ) + + group_by = cs._group_by_clause._compiler_dispatch( + self, asfrom=asfrom, **kwargs) if group_by: text += " GROUP BY " + group_by text += self.order_by_clause(cs, **kwargs) - text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or "" + text += (cs._limit_clause is not None + or cs._offset_clause is not None) and \ + self.limit_clause(cs, **kwargs) or "" + + if self.ctes and toplevel: + text = self._render_cte_clause() + text self.stack.pop(-1) if asfrom and parens: @@ -466,132 +956,288 @@ class SQLCompiler(engine.Compiled): else: return text - def visit_unary(self, unary, **kw): - s = self.process(unary.element, **kw) - if unary.operator: - s = OPERATORS[unary.operator] + s - if unary.modifier: - s = s + OPERATORS[unary.modifier] - return s + def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): + attrname = "visit_%s_%s%s" % ( + operator_.__name__, qualifier1, + "_" + qualifier2 if qualifier2 else "") + return getattr(self, attrname, None) + + def visit_unary(self, unary, **kw): + if unary.operator: + if unary.modifier: + raise exc.CompileError( + "Unary expression does not support operator " + "and modifier simultaneously") + disp = self._get_operator_dispatch( + unary.operator, "unary", "operator") + if disp: + return disp(unary, unary.operator, **kw) + else: + return self._generate_generic_unary_operator( + unary, OPERATORS[unary.operator], **kw) + elif unary.modifier: + disp = 
self._get_operator_dispatch( + unary.modifier, "unary", "modifier") + if disp: + return disp(unary, unary.modifier, **kw) + else: + return self._generate_generic_unary_modifier( + unary, OPERATORS[unary.modifier], **kw) + else: + raise exc.CompileError( + "Unary expression has no operator or modifier") + + def visit_istrue_unary_operator(self, element, operator, **kw): + if self.dialect.supports_native_boolean: + return self.process(element.element, **kw) + else: + return "%s = 1" % self.process(element.element, **kw) + + def visit_isfalse_unary_operator(self, element, operator, **kw): + if self.dialect.supports_native_boolean: + return "NOT %s" % self.process(element.element, **kw) + else: + return "%s = 0" % self.process(element.element, **kw) + + def visit_notmatch_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_binary( + binary, override_operator=operators.match_op) + + def visit_binary(self, binary, override_operator=None, + eager_grouping=False, **kw): - def visit_binary(self, binary, **kw): # don't allow "? = ?" to render if self.ansi_bind_rules and \ - isinstance(binary.left, sql._BindParamClause) and \ - isinstance(binary.right, sql._BindParamClause): + isinstance(binary.left, elements.BindParameter) and \ + isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True - - return self._operator_dispatch(binary.operator, - binary, - lambda opstr: self.process(binary.left, **kw) + - opstr + - self.process(binary.right, **kw), - **kw + + operator_ = override_operator or binary.operator + disp = self._get_operator_dispatch(operator_, "binary", None) + if disp: + return disp(binary, operator_, **kw) + else: + try: + opstring = OPERATORS[operator_] + except KeyError: + raise exc.UnsupportedCompilationError(self, operator_) + else: + return self._generate_generic_binary(binary, opstring, **kw) + + def visit_custom_op_binary(self, element, operator, **kw): + kw['eager_grouping'] = operator.eager_grouping + return self._generate_generic_binary( + element, " " + operator.opstring + " ", **kw) + + def visit_custom_op_unary_operator(self, element, operator, **kw): + return self._generate_generic_unary_operator( + element, operator.opstring + " ", **kw) + + def visit_custom_op_unary_modifier(self, element, operator, **kw): + return self._generate_generic_unary_modifier( + element, " " + operator.opstring, **kw) + + def _generate_generic_binary( + self, binary, opstring, eager_grouping=False, **kw): + + _in_binary = kw.get('_in_binary', False) + + kw['_in_binary'] = True + text = binary.left._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw) + \ + opstring + \ + binary.right._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw) + + if _in_binary and eager_grouping: + text = "(%s)" % text + return text + + def _generate_generic_unary_operator(self, unary, opstring, **kw): + return opstring + unary.element._compiler_dispatch(self, **kw) + + def _generate_generic_unary_modifier(self, unary, opstring, **kw): + return unary.element._compiler_dispatch(self, **kw) + opstring + + @util.memoized_property + def _like_percent_literal(self): + return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE) + + def visit_contains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__add__(binary.right).__add__(percent) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_notcontains_op_binary(self, binary, operator, **kw): + binary = 
binary._clone() + percent = self._like_percent_literal + binary.right = percent.__add__(binary.right).__add__(percent) + return self.visit_notlike_op_binary(binary, operator, **kw) + + def visit_startswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__radd__( + binary.right ) + return self.visit_like_op_binary(binary, operator, **kw) - def visit_like_op(self, binary, **kw): + def visit_notstartswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__radd__( + binary.right + ) + return self.visit_notlike_op_binary(binary, operator, **kw) + + def visit_endswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__add__(binary.right) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_notendswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.__add__(binary.right) + return self.visit_notlike_op_binary(binary, operator, **kw) + + def visit_like_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return '%s LIKE %s' % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ - + (escape and ' ESCAPE \'%s\'' % escape or '') - def visit_notlike_op(self, binary, **kw): + # TODO: use ternary here, not "and"/ "or" + return '%s LIKE %s' % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return '%s NOT LIKE %s' % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ - + (escape and ' ESCAPE \'%s\'' % escape or '') - - def visit_ilike_op(self, binary, **kw): + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) LIKE lower(%s)' % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ - + (escape and ' ESCAPE \'%s\'' % escape or '') - - def visit_notilike_op(self, binary, **kw): + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) NOT LIKE lower(%s)' % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ - + (escape and ' ESCAPE \'%s\'' % escape or '') - - def _operator_dispatch(self, operator, element, fn, **kw): - if util.callable(operator): - disp = getattr(self, "visit_%s" % operator.__name__, None) - if disp: - return disp(element, **kw) - else: - return fn(OPERATORS[operator]) - else: - return fn(" " + operator + " ") - - def visit_bindparam(self, bindparam, within_columns_clause=False, - literal_binds=False, **kwargs): + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + + ( + ' ESCAPE ' + + 
self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) + + def visit_between_op_binary(self, binary, operator, **kw): + symmetric = binary.modifiers.get("symmetric", False) + return self._generate_generic_binary( + binary, " BETWEEN SYMMETRIC " + if symmetric else " BETWEEN ", **kw) + + def visit_notbetween_op_binary(self, binary, operator, **kw): + symmetric = binary.modifiers.get("symmetric", False) + return self._generate_generic_binary( + binary, " NOT BETWEEN SYMMETRIC " + if symmetric else " NOT BETWEEN ", **kw) + + def visit_bindparam(self, bindparam, within_columns_clause=False, + literal_binds=False, + skip_bind_expression=False, + **kwargs): + if not skip_bind_expression and bindparam.type._has_bind_expression: + bind_expression = bindparam.type.bind_expression(bindparam) + return self.process(bind_expression, + skip_bind_expression=True) + if literal_binds or \ - (within_columns_clause and \ + (within_columns_clause and self.ansi_bind_rules): - if bindparam.value is None: - raise exc.CompileError("Bind parameter without a " - "renderable value not allowed here.") - return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs) - + if bindparam.value is None and bindparam.callable is None: + raise exc.CompileError("Bind parameter '%s' without a " + "renderable value not allowed here." + % bindparam.key) + return self.render_literal_bindparam( + bindparam, within_columns_clause=True, **kwargs) + name = self._truncate_bindparam(bindparam) + if name in self.binds: existing = self.binds[name] if existing is not bindparam: - if existing.unique or bindparam.unique: + if (existing.unique or bindparam.unique) and \ + not existing.proxy_set.intersection( + bindparam.proxy_set): raise exc.CompileError( - "Bind parameter '%s' conflicts with " - "unique bind parameter of the same name" % bindparam.key - ) - elif getattr(existing, '_is_crud', False): + "Bind parameter '%s' conflicts with " + "unique bind parameter of the same name" % + bindparam.key + ) + elif existing._is_crud or bindparam._is_crud: raise exc.CompileError( - "Bind parameter name '%s' is reserved " - "for the VALUES or SET clause of this insert/update statement." - % bindparam.key - ) - + "bindparam() name '%s' is reserved " + "for automatic usage in the VALUES or SET " + "clause of this " + "insert/update statement. Please use a " + "name other than column name when using bindparam() " + "with insert() or update() (for example, 'b_%s')." % + (bindparam.key, bindparam.key) + ) + self.binds[bindparam.key] = self.binds[name] = bindparam - return self.bindparam_string(name) - + + return self.bindparam_string(name, **kwargs) + def render_literal_bindparam(self, bindparam, **kw): - value = bindparam.value - processor = bindparam.bind_processor(self.dialect) - if processor: - value = processor(value) + value = bindparam.effective_value return self.render_literal_value(value, bindparam.type) - + def render_literal_value(self, value, type_): """Render the value of a bind parameter as a quoted literal. - - This is used for statement sections that do not accept bind paramters + + This is used for statement sections that do not accept bind parameters on the target driver/database. - + This should be implemented by subclasses using the quoting services of the DBAPI. 
- + """ - if isinstance(value, basestring): - value = value.replace("'", "''") - return "'%s'" % value - elif value is None: - return "NULL" - elif isinstance(value, (float, int, long)): - return repr(value) - elif isinstance(value, decimal.Decimal): - return str(value) + + processor = type_._cached_literal_processor(self.dialect) + if processor: + return processor(value) else: - raise NotImplementedError("Don't know how to literal-quote value %r" % value) - + raise NotImplementedError( + "Don't know how to literal-quote value %r" % value) + def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: return self.bind_names[bindparam] bind_name = bindparam.key - bind_name = isinstance(bind_name, sql._generated_label) and \ - self._truncated_identifier("bindparam", bind_name) or bind_name + if isinstance(bind_name, elements._truncated_label): + bind_name = self._truncated_identifier("bindparam", bind_name) + # add to bind_names for translation self.bind_names[bindparam] = bind_name @@ -601,502 +1247,1038 @@ class SQLCompiler(engine.Compiled): if (ident_class, name) in self.truncated_names: return self.truncated_names[(ident_class, name)] - anonname = name % self.anon_map + anonname = name.apply_map(self.anon_map) - if len(anonname) > self.label_length: + if len(anonname) > self.label_length - 6: counter = self.truncated_names.get(ident_class, 1) - truncname = anonname[0:max(self.label_length - 6, 0)] + "_" + hex(counter)[2:] + truncname = anonname[0:max(self.label_length - 6, 0)] + \ + "_" + hex(counter)[2:] self.truncated_names[ident_class] = counter + 1 else: truncname = anonname self.truncated_names[(ident_class, name)] = truncname return truncname - + def _anonymize(self, name): return name % self.anon_map - + def _process_anon(self, key): (ident, derived) = key.split(' ', 1) anonymous_counter = self.anon_map.get(derived, 1) self.anon_map[derived] = anonymous_counter + 1 return derived + "_" + str(anonymous_counter) - def bindparam_string(self, name): + def bindparam_string(self, name, positional_names=None, **kw): if self.positional: - self.positiontup.append(name) - return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} - else: - return self.bindtemplate % {'name':name} + if positional_names is not None: + positional_names.append(name) + else: + self.positiontup.append(name) + return self.bindtemplate % {'name': name} - def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): + def visit_cte(self, cte, asfrom=False, ashint=False, + fromhints=None, + **kwargs): + self._init_cte_state() + + if isinstance(cte.name, elements._truncated_label): + cte_name = self._truncated_identifier("alias", cte.name) + else: + cte_name = cte.name + + if cte_name in self.ctes_by_name: + existing_cte = self.ctes_by_name[cte_name] + # we've generated a same-named CTE that we are enclosed in, + # or this is the same CTE. just return the name. + if cte in existing_cte._restates or cte is existing_cte: + return self.preparer.format_alias(cte, cte_name) + elif existing_cte in cte._restates: + # we've generated a same-named CTE that is + # enclosed in us - we take precedence, so + # discard the text for the "inner". 
+ del self.ctes[existing_cte] + else: + raise exc.CompileError( + "Multiple, unrelated CTEs found with " + "the same name: %r" % + cte_name) + + self.ctes_by_name[cte_name] = cte + + # look for embedded DML ctes and propagate autocommit + if 'autocommit' in cte.element._execution_options and \ + 'autocommit' not in self.execution_options: + self.execution_options = self.execution_options.union( + {"autocommit": cte.element._execution_options['autocommit']}) + + if cte._cte_alias is not None: + orig_cte = cte._cte_alias + if orig_cte not in self.ctes: + self.visit_cte(orig_cte, **kwargs) + cte_alias_name = cte._cte_alias.name + if isinstance(cte_alias_name, elements._truncated_label): + cte_alias_name = self._truncated_identifier( + "alias", cte_alias_name) + else: + orig_cte = cte + cte_alias_name = None + if not cte_alias_name and cte not in self.ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + if isinstance(cte.original, selectable.Select): + col_source = cte.original + elif isinstance(cte.original, selectable.CompoundSelect): + col_source = cte.original.selects[0] + else: + assert False + recur_cols = [c for c in + util.unique_list(col_source.inner_columns) + if c is not None] + + text += "(%s)" % (", ".join( + self.preparer.format_column(ident) + for ident in recur_cols)) + + if self.positional: + kwargs['positional_names'] = self.cte_positional[cte] = [] + + text += " AS \n" + \ + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + + if cte._suffixes: + text += " " + self._generate_prefixes( + cte, cte._suffixes, **kwargs) + + self.ctes[cte] = text + + if asfrom: + if cte_alias_name: + text = self.preparer.format_alias(cte, cte_alias_name) + text += self.get_render_as_alias_suffix(cte_name) + else: + return self.preparer.format_alias(cte, cte_name) + return text + + def visit_alias(self, alias, asfrom=False, ashint=False, + iscrud=False, + fromhints=None, **kwargs): if asfrom or ashint: - alias_name = isinstance(alias.name, sql._generated_label) and \ - self._truncated_identifier("alias", alias.name) or alias.name + if isinstance(alias.name, elements._truncated_label): + alias_name = self._truncated_identifier("alias", alias.name) + else: + alias_name = alias.name + if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: - ret = self.process(alias.original, asfrom=True, **kwargs) + " AS " + \ - self.preparer.format_alias(alias, alias_name) - + ret = alias.original._compiler_dispatch(self, + asfrom=True, **kwargs) + \ + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name)) + if fromhints and alias in fromhints: - hinttext = self.get_from_hint_text(alias, fromhints[alias]) - if hinttext: - ret += " " + hinttext - + ret = self.format_from_hint_text(ret, alias, + fromhints[alias], iscrud) + return ret else: - return self.process(alias.original, **kwargs) + return alias.original._compiler_dispatch(self, **kwargs) - def label_select_column(self, select, column, asfrom): - """label columns present in a select().""" + def visit_lateral(self, lateral, **kw): + kw['lateral'] = True + return "LATERAL %s" % self.visit_alias(lateral, **kw) - if isinstance(column, sql._Label): - return column + def visit_tablesample(self, tablesample, asfrom=False, **kw): + text = "%s TABLESAMPLE %s" % ( + self.visit_alias(tablesample, asfrom=True, **kw), + tablesample._get_method()._compiler_dispatch(self, **kw)) - if select is not None and select.use_labels and 
column._label: - return _CompileLabel(column, column._label) + if tablesample.seed is not None: + text += " REPEATABLE (%s)" % ( + tablesample.seed._compiler_dispatch(self, **kw)) - if \ + return text + + def get_render_as_alias_suffix(self, alias_name_text): + return " AS " + alias_name_text + + def _add_to_result_map(self, keyname, name, objects, type_): + self._result_columns.append((keyname, name, objects, type_)) + + def _label_select_column(self, select, column, + populate_result_map, + asfrom, column_clause_args, + name=None, + within_columns_clause=True): + """produce labeled columns present in a select().""" + + if column.type._has_column_expression and \ + populate_result_map: + col_expr = column.type.column_expression(column) + add_to_result_map = lambda keyname, name, objects, type_: \ + self._add_to_result_map( + keyname, name, + (column,) + objects, type_) + else: + col_expr = column + if populate_result_map: + add_to_result_map = self._add_to_result_map + else: + add_to_result_map = None + + if not within_columns_clause: + result_expr = col_expr + elif isinstance(column, elements.Label): + if col_expr is not column: + result_expr = _CompileLabel( + col_expr, + column.name, + alt_names=(column.element,) + ) + else: + result_expr = col_expr + + elif select is not None and name: + result_expr = _CompileLabel( + col_expr, + name, + alt_names=(column._key_label,) + ) + + elif \ asfrom and \ - isinstance(column, sql.ColumnClause) and \ + isinstance(column, elements.ColumnClause) and \ not column.is_literal and \ column.table is not None and \ - not isinstance(column.table, sql.Select): - return _CompileLabel(column, sql._generated_label(column.name)) - elif not isinstance(column, - (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \ - and (not hasattr(column, 'name') or isinstance(column, sql.Function)): - return _CompileLabel(column, column.anon_label) + not isinstance(column.table, selectable.Select): + result_expr = _CompileLabel(col_expr, + elements._as_truncated(column.name), + alt_names=(column.key,)) + elif ( + not isinstance(column, elements.TextClause) and + ( + not isinstance(column, elements.UnaryExpression) or + column.wraps_column_expression + ) and + ( + not hasattr(column, 'name') or + isinstance(column, functions.Function) + ) + ): + result_expr = _CompileLabel(col_expr, column.anon_label) + elif col_expr is not column: + # TODO: are we sure "column" has a .name and .key here ? 
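
The branches of _label_select_column above are the internals behind use_labels; a quick sketch of what they produce (invented tables, default dialect):

    from sqlalchemy import Column, Integer, MetaData, Table, select

    metadata = MetaData()
    a = Table('a', metadata, Column('id', Integer))
    b = Table('b', metadata, Column('id', Integer))

    # apply_labels() requests table-qualified labels, keeping the two
    # "id" columns distinguishable in the result map
    print(select([a.c.id, b.c.id]).apply_labels())
    # SELECT a.id AS a_id, b.id AS b_id FROM a, b
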
+ # assert isinstance(column, elements.ColumnClause) + result_expr = _CompileLabel(col_expr, + elements._as_truncated(column.name), + alt_names=(column.key,)) else: - return column + result_expr = col_expr + + column_clause_args.update( + within_columns_clause=within_columns_clause, + add_to_result_map=add_to_result_map + ) + return result_expr._compiler_dispatch( + self, + **column_clause_args + ) + + def format_from_hint_text(self, sqltext, table, hint, iscrud): + hinttext = self.get_from_hint_text(table, hint) + if hinttext: + sqltext += " " + hinttext + return sqltext def get_select_hint_text(self, byfroms): return None - + def get_from_hint_text(self, table, text): return None - - def visit_select(self, select, asfrom=False, parens=True, - iswrapper=False, fromhints=None, - compound_index=1, **kwargs): - entry = self.stack and self.stack[-1] or {} - - existingfroms = entry.get('from', None) + def get_crud_hint_text(self, table, text): + return None - froms = select._get_display_froms(existingfroms) + def get_statement_hint_text(self, hint_texts): + return " ".join(hint_texts) - correlate_froms = set(sql._from_objects(*froms)) + def _transform_select_for_nested_joins(self, select): + """Rewrite any "a JOIN (b JOIN c)" expression as + "a JOIN (select * from b JOIN c) AS anon", to support + databases that can't parse a parenthesized join correctly + (i.e. sqlite < 3.7.16). - # TODO: might want to propagate existing froms for select(select(select)) - # where innermost select should correlate to outermost - # if existingfroms: - # correlate_froms = correlate_froms.union(existingfroms) + """ + cloned = {} + column_translate = [{}] - self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper}) + def visit(element, **kw): + if element in column_translate[-1]: + return column_translate[-1][element] - if compound_index==1 and not entry or entry.get('iswrapper', False): - column_clause_args = {'result_map':self.result_map} - else: - column_clause_args = {} + elif element in cloned: + return cloned[element] - # the actual list of columns to print in the SELECT column list. - inner_columns = [ - c for c in [ - self.process( - self.label_select_column(select, co, asfrom=asfrom), - within_columns_clause=True, - **column_clause_args) - for co in util.unique_list(select.inner_columns) - ] - if c is not None + newelem = cloned[element] = element._clone() + + if newelem.is_selectable and newelem._is_join and \ + isinstance(newelem.right, selectable.FromGrouping): + + newelem._reset_exported() + newelem.left = visit(newelem.left, **kw) + + right = visit(newelem.right, **kw) + + selectable_ = selectable.Select( + [right.element], + use_labels=True).alias() + + for c in selectable_.c: + c._key_label = c.key + c._label = c.name + + translate_dict = dict( + zip(newelem.right.element.c, selectable_.c) + ) + + # translating from both the old and the new + # because different select() structures will lead us + # to traverse differently + translate_dict[right.element.left] = selectable_ + translate_dict[right.element.right] = selectable_ + translate_dict[newelem.right.element.left] = selectable_ + translate_dict[newelem.right.element.right] = selectable_ + + # propagate translations that we've gained + # from nested visit(newelem.right) outwards + # to the enclosing select here. this happens + # only when we have more than one level of right + # join nesting, i.e. 
"a JOIN (b JOIN (c JOIN d))" + for k, v in list(column_translate[-1].items()): + if v in translate_dict: + # remarkably, no current ORM tests (May 2013) + # hit this condition, only test_join_rewriting + # does. + column_translate[-1][k] = translate_dict[v] + + column_translate[-1].update(translate_dict) + + newelem.right = selectable_ + + newelem.onclause = visit(newelem.onclause, **kw) + + elif newelem._is_from_container: + # if we hit an Alias, CompoundSelect or ScalarSelect, put a + # marker in the stack. + kw['transform_clue'] = 'select_container' + newelem._copy_internals(clone=visit, **kw) + elif newelem.is_selectable and newelem._is_select: + barrier_select = kw.get('transform_clue', None) == \ + 'select_container' + # if we're still descended from an + # Alias/CompoundSelect/ScalarSelect, we're + # in a FROM clause, so start with a new translate collection + if barrier_select: + column_translate.append({}) + kw['transform_clue'] = 'inside_select' + newelem._copy_internals(clone=visit, **kw) + if barrier_select: + del column_translate[-1] + else: + newelem._copy_internals(clone=visit, **kw) + + return newelem + + return visit(select) + + def _transform_result_map_for_nested_joins( + self, select, transformed_select): + inner_col = dict((c._key_label, c) for + c in transformed_select.inner_columns) + + d = dict( + (inner_col[c._key_label], c) + for c in select.inner_columns + ) + + self._result_columns = [ + (key, name, tuple([d.get(col, col) for col in objs]), typ) + for key, name, objs, typ in self._result_columns ] - + + _default_stack_entry = util.immutabledict([ + ('correlate_froms', frozenset()), + ('asfrom_froms', frozenset()) + ]) + + def _display_froms_for_select(self, select, asfrom, lateral=False): + # utility method to help external dialects + # get the correct from list for a select. + # specifically the oracle dialect needs this feature + # right now. + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + correlate_froms = entry['correlate_froms'] + asfrom_froms = entry['asfrom_froms'] + + if asfrom and not lateral: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms), + implicit_correlate_froms=()) + else: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms) + return froms + + def visit_select(self, select, asfrom=False, parens=True, + fromhints=None, + compound_index=0, + nested_join_translation=False, + select_wraps_for=None, + lateral=False, + **kwargs): + + needs_nested_translation = \ + select.use_labels and \ + not nested_join_translation and \ + not self.stack and \ + not self.dialect.supports_right_nested_joins + + if needs_nested_translation: + transformed_select = self._transform_select_for_nested_joins( + select) + text = self.visit_select( + transformed_select, asfrom=asfrom, parens=parens, + fromhints=fromhints, + compound_index=compound_index, + nested_join_translation=True, **kwargs + ) + + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + populate_result_map = toplevel or \ + ( + compound_index == 0 and entry.get( + 'need_result_map_for_compound', False) + ) or entry.get('need_result_map_for_nested', False) + + # this was first proposed as part of #3372; however, it is not + # reached in current tests and could possibly be an assertion + # instead. 
+ if not populate_result_map and 'add_to_result_map' in kwargs: + del kwargs['add_to_result_map'] + + if needs_nested_translation: + if populate_result_map: + self._transform_result_map_for_nested_joins( + select, transformed_select) + return text + + froms = self._setup_select_stack(select, entry, asfrom, lateral) + + column_clause_args = kwargs.copy() + column_clause_args.update({ + 'within_label_clause': False, + 'within_columns_clause': False + }) + text = "SELECT " # we're off to a good start ! if select._hints: - byfrom = dict([ - (from_, hinttext % {'name':self.process(from_, ashint=True)}) - for (from_, dialect), hinttext in - select._hints.iteritems() - if dialect in ('*', self.dialect.name) - ]) - hint_text = self.get_select_hint_text(byfrom) + hint_text, byfrom = self._setup_select_hints(select) if hint_text: text += hint_text + " " - + else: + byfrom = None + if select._prefixes: - text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " " - text += self.get_select_precolumns(select) + text += self._generate_prefixes( + select, select._prefixes, **kwargs) + + text += self.get_select_precolumns(select, **kwargs) + # the actual list of columns to print in the SELECT column list. + inner_columns = [ + c for c in [ + self._label_select_column( + select, + column, + populate_result_map, asfrom, + column_clause_args, + name=name) + for name, column in select._columns_plus_names + ] + if c is not None + ] + + if populate_result_map and select_wraps_for is not None: + # if this select is a compiler-generated wrapper, + # rewrite the targeted columns in the result map + + translate = dict( + zip( + [name for (key, name) in select._columns_plus_names], + [name for (key, name) in + select_wraps_for._columns_plus_names]) + ) + + self._result_columns = [ + (key, name, tuple(translate.get(o, o) for o in obj), type_) + for key, name, obj, type_ in self._result_columns + ] + + text = self._compose_select_body( + text, select, inner_columns, froms, byfrom, kwargs) + + if select._statement_hints: + per_dialect = [ + ht for (dialect_name, ht) + in select._statement_hints + if dialect_name in ('*', self.dialect.name) + ] + if per_dialect: + text += " " + self.get_statement_hint_text(per_dialect) + + if self.ctes and toplevel: + text = self._render_cte_clause() + text + + if select._suffixes: + text += " " + self._generate_prefixes( + select, select._suffixes, **kwargs) + + self.stack.pop(-1) + + if (asfrom or lateral) and parens: + return "(" + text + ")" + else: + return text + + def _setup_select_hints(self, select): + byfrom = dict([ + (from_, hinttext % { + 'name': from_._compiler_dispatch( + self, ashint=True) + }) + for (from_, dialect), hinttext in + select._hints.items() + if dialect in ('*', self.dialect.name) + ]) + hint_text = self.get_select_hint_text(byfrom) + return hint_text, byfrom + + def _setup_select_stack(self, select, entry, asfrom, lateral): + correlate_froms = entry['correlate_froms'] + asfrom_froms = entry['asfrom_froms'] + + if asfrom and not lateral: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms), + implicit_correlate_froms=()) + else: + froms = select._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms) + + new_correlate_froms = set(selectable._from_objects(*froms)) + all_correlate_froms = new_correlate_froms.union(correlate_froms) + + new_entry = { + 'asfrom_froms': new_correlate_froms, + 'correlate_froms': all_correlate_froms, + 
'selectable': select, + } + self.stack.append(new_entry) + + return froms + + def _compose_select_body( + self, text, select, inner_columns, froms, byfrom, kwargs): text += ', '.join(inner_columns) if froms: text += " \nFROM " - + if select._hints: - text += ', '.join([self.process(f, - asfrom=True, fromhints=byfrom, - **kwargs) - for f in froms]) + text += ', '.join( + [f._compiler_dispatch(self, asfrom=True, + fromhints=byfrom, **kwargs) + for f in froms]) else: - text += ', '.join([self.process(f, - asfrom=True, **kwargs) - for f in froms]) + text += ', '.join( + [f._compiler_dispatch(self, asfrom=True, **kwargs) + for f in froms]) else: text += self.default_from() if select._whereclause is not None: - t = self.process(select._whereclause, **kwargs) + t = select._whereclause._compiler_dispatch(self, **kwargs) if t: text += " \nWHERE " + t if select._group_by_clause.clauses: - group_by = self.process(select._group_by_clause, **kwargs) + group_by = select._group_by_clause._compiler_dispatch( + self, **kwargs) if group_by: text += " GROUP BY " + group_by if select._having is not None: - t = self.process(select._having, **kwargs) + t = select._having._compiler_dispatch(self, **kwargs) if t: text += " \nHAVING " + t if select._order_by_clause.clauses: text += self.order_by_clause(select, **kwargs) - if select._limit is not None or select._offset is not None: - text += self.limit_clause(select) - if select.for_update: - text += self.for_update_clause(select) - self.stack.pop(-1) + if (select._limit_clause is not None or + select._offset_clause is not None): + text += self.limit_clause(select, **kwargs) - if asfrom and parens: - return "(" + text + ")" + if select._for_update_arg is not None: + text += self.for_update_clause(select, **kwargs) + + return text + + def _generate_prefixes(self, stmt, prefixes, **kw): + clause = " ".join( + prefix._compiler_dispatch(self, **kw) + for prefix, dialect_name in prefixes + if dialect_name is None or + dialect_name == self.dialect.name + ) + if clause: + clause += " " + return clause + + def _render_cte_clause(self): + if self.positional: + self.positiontup = sum([ + self.cte_positional[cte] + for cte in self.ctes], []) + \ + self.positiontup + cte_text = self.get_cte_preamble(self.ctes_recursive) + " " + cte_text += ", \n".join( + [txt for txt in self.ctes.values()] + ) + cte_text += "\n " + return cte_text + + def get_cte_preamble(self, recursive): + if recursive: + return "WITH RECURSIVE" else: - return text + return "WITH" + + def get_select_precolumns(self, select, **kw): + """Called when building a ``SELECT`` statement, position is just + before column list. - def get_select_precolumns(self, select): - """Called when building a ``SELECT`` statement, position is just before - column list. 
- """ return select._distinct and "DISTINCT " or "" def order_by_clause(self, select, **kw): - order_by = self.process(select._order_by_clause, **kw) + order_by = select._order_by_clause._compiler_dispatch(self, **kw) if order_by: return " ORDER BY " + order_by else: return "" - def for_update_clause(self, select): - if select.for_update: - return " FOR UPDATE" - else: - return "" + def for_update_clause(self, select, **kw): + return " FOR UPDATE" - def limit_clause(self, select): + def returning_clause(self, stmt, returning_cols): + raise exc.CompileError( + "RETURNING is not supported by this " + "dialect's statement compiler.") + + def limit_clause(self, select, **kw): text = "" - if select._limit is not None: - text += " \n LIMIT " + str(select._limit) - if select._offset is not None: - if select._limit is None: - text += " \n LIMIT -1" - text += " OFFSET " + str(select._offset) + if select._limit_clause is not None: + text += "\n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += "\n LIMIT -1" + text += " OFFSET " + self.process(select._offset_clause, **kw) return text - def visit_table(self, table, asfrom=False, ashint=False, fromhints=None, **kwargs): + def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, + fromhints=None, use_schema=True, **kwargs): if asfrom or ashint: - if getattr(table, "schema", None): - ret = self.preparer.quote_schema(table.schema, table.quote_schema) + \ - "." + self.preparer.quote(table.name, table.quote) + effective_schema = self.preparer.schema_for_object(table) + + if use_schema and effective_schema: + ret = self.preparer.quote_schema(effective_schema) + \ + "." + self.preparer.quote(table.name) else: - ret = self.preparer.quote(table.name, table.quote) + ret = self.preparer.quote(table.name) if fromhints and table in fromhints: - hinttext = self.get_from_hint_text(table, fromhints[table]) - if hinttext: - ret += " " + hinttext + ret = self.format_from_hint_text(ret, table, + fromhints[table], iscrud) return ret else: return "" def visit_join(self, join, asfrom=False, **kwargs): - return (self.process(join.left, asfrom=True, **kwargs) + \ - (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ - self.process(join.right, asfrom=True, **kwargs) + " ON " + \ - self.process(join.onclause, **kwargs)) + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " JOIN " + return ( + join.left._compiler_dispatch(self, asfrom=True, **kwargs) + + join_type + + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + " ON " + + join.onclause._compiler_dispatch(self, **kwargs) + ) - def visit_sequence(self, seq): - return None + def _setup_crud_hints(self, stmt, table_text): + dialect_hints = dict([ + (table, hint_text) + for (table, dialect), hint_text in + stmt._hints.items() + if dialect in ('*', self.dialect.name) + ]) + if stmt.table in dialect_hints: + table_text = self.format_from_hint_text( + table_text, + stmt.table, + dialect_hints[stmt.table], + True + ) + return dialect_hints, table_text - def visit_insert(self, insert_stmt): - self.isinsert = True - colparams = self._get_colparams(insert_stmt) + def visit_insert(self, insert_stmt, asfrom=False, **kw): + toplevel = not self.stack - if not colparams and \ + self.stack.append( + {'correlate_froms': set(), + "asfrom_froms": set(), + "selectable": insert_stmt}) + + crud_params = crud._setup_crud_params( + self, insert_stmt, 
crud.ISINSERT, **kw) + + if not crud_params and \ not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: - raise exc.CompileError("The version of %s you are using does " - "not support empty inserts." % - self.dialect.name) + raise exc.CompileError("The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % + self.dialect.name) + + if insert_stmt._has_multi_parameters: + if not self.dialect.supports_multivalues_insert: + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support " + "in-place multirow inserts." % + self.dialect.name) + crud_params_single = crud_params[0] + else: + crud_params_single = crud_params preparer = self.preparer supports_default_values = self.dialect.supports_default_values - - text = "INSERT" - - prefixes = [self.process(x) for x in insert_stmt._prefixes] - if prefixes: - text += " " + " ".join(prefixes) - - text += " INTO " + preparer.format_table(insert_stmt.table) - - if colparams or not supports_default_values: + + text = "INSERT " + + if insert_stmt._prefixes: + text += self._generate_prefixes(insert_stmt, + insert_stmt._prefixes, **kw) + + text += "INTO " + table_text = preparer.format_table(insert_stmt.table) + + if insert_stmt._hints: + dialect_hints, table_text = self._setup_crud_hints( + insert_stmt, table_text) + else: + dialect_hints = None + + text += table_text + + if crud_params_single or not supports_default_values: text += " (%s)" % ', '.join([preparer.format_column(c[0]) - for c in colparams]) + for c in crud_params_single]) if self.returning or insert_stmt._returning: - self.returning = self.returning or insert_stmt._returning - returning_clause = self.returning_clause(insert_stmt, self.returning) - + returning_clause = self.returning_clause( + insert_stmt, self.returning or insert_stmt._returning) + if self.returning_precedes_values: text += " " + returning_clause + else: + returning_clause = None - if not colparams and supports_default_values: + if insert_stmt.select is not None: + text += " %s" % self.process(self._insert_from_select, **kw) + elif not crud_params and supports_default_values: text += " DEFAULT VALUES" + elif insert_stmt._has_multi_parameters: + text += " VALUES %s" % ( + ", ".join( + "(%s)" % ( + ', '.join(c[1] for c in crud_param_set) + ) + for crud_param_set in crud_params + ) + ) else: text += " VALUES (%s)" % \ - ', '.join([c[1] for c in colparams]) - - if self.returning and not self.returning_precedes_values: + ', '.join([c[1] for c in crud_params]) + + if insert_stmt._post_values_clause is not None: + post_values_clause = self.process( + insert_stmt._post_values_clause, **kw) + if post_values_clause: + text += " " + post_values_clause + + if returning_clause and not self.returning_precedes_values: text += " " + returning_clause - - return text - - def visit_update(self, update_stmt): - self.stack.append({'from': set([update_stmt.table])}) - self.isupdate = True - colparams = self._get_colparams(update_stmt) + if self.ctes and toplevel: + text = self._render_cte_clause() + text - text = "UPDATE " + self.preparer.format_table(update_stmt.table) - - text += ' SET ' + \ - ', '.join( - self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] - for c in colparams - ) - - if update_stmt._returning: - self.returning = update_stmt._returning - if self.returning_precedes_values: - text += " " + self.returning_clause(update_stmt, update_stmt._returning) - - if update_stmt._whereclause is not None: - text += " 
WHERE " + self.process(update_stmt._whereclause) - - if self.returning and not self.returning_precedes_values: - text += " " + self.returning_clause(update_stmt, update_stmt._returning) - self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text - def _create_crud_bind_param(self, col, value, required=False): - bindparam = sql.bindparam(col.key, value, type_=col.type, required=required) - bindparam._is_crud = True - if col.key in self.binds: - raise exc.CompileError( - "Bind parameter name '%s' is reserved " - "for the VALUES or SET clause of this insert/update statement." - % col.key - ) - - self.binds[col.key] = bindparam - return self.bindparam_string(self._truncate_bindparam(bindparam)) - - def _get_colparams(self, stmt): - """create a set of tuples representing column/string pairs for use - in an INSERT or UPDATE statement. + def update_limit_clause(self, update_stmt): + """Provide a hook for MySQL to add LIMIT to the UPDATE""" + return None - Also generates the Compiled object's postfetch, prefetch, and returning - column collections, used for default handling and ultimately - populating the ResultProxy's prefetch_cols() and postfetch_cols() - collections. + def update_tables_clause(self, update_stmt, from_table, + extra_froms, **kw): + """Provide a hook to override the initial table clause + in an UPDATE statement. + + MySQL overrides this. """ + kw['asfrom'] = True + return from_table._compiler_dispatch(self, iscrud=True, **kw) - self.postfetch = [] - self.prefetch = [] - self.returning = [] + def update_from_clause(self, update_stmt, + from_table, extra_froms, + from_hints, + **kw): + """Provide a hook to override the generation of an + UPDATE..FROM clause. - # no parameters in the statement, no parameters in the - # compiled params - return binds for all columns - if self.column_keys is None and stmt.parameters is None: - return [ - (c, self._create_crud_bind_param(c, None, required=True)) - for c in stmt.table.columns - ] + MySQL and MSSQL override this. 
- required = object() - - # if we have statement parameters - set defaults in the - # compiled params - if self.column_keys is None: - parameters = {} + """ + return "FROM " + ', '.join( + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) + for t in extra_froms) + + def visit_update(self, update_stmt, asfrom=False, **kw): + toplevel = not self.stack + + self.stack.append( + {'correlate_froms': set([update_stmt.table]), + "asfrom_froms": set([update_stmt.table]), + "selectable": update_stmt}) + + extra_froms = update_stmt._extra_froms + + text = "UPDATE " + + if update_stmt._prefixes: + text += self._generate_prefixes(update_stmt, + update_stmt._prefixes, **kw) + + table_text = self.update_tables_clause(update_stmt, update_stmt.table, + extra_froms, **kw) + + crud_params = crud._setup_crud_params( + self, update_stmt, crud.ISUPDATE, **kw) + + if update_stmt._hints: + dialect_hints, table_text = self._setup_crud_hints( + update_stmt, table_text) else: - parameters = dict((sql._column_as_key(key), required) - for key in self.column_keys - if not stmt.parameters or key not in stmt.parameters) + dialect_hints = None - if stmt.parameters is not None: - for k, v in stmt.parameters.iteritems(): - parameters.setdefault(sql._column_as_key(k), v) + text += table_text - # create a list of column assignment clauses as tuples - values = [] - - need_pks = self.isinsert and \ - not self.inline and \ - not stmt._returning - - implicit_returning = need_pks and \ - self.dialect.implicit_returning and \ - stmt.table.implicit_returning - - postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid - - # iterating through columns at the top to maintain ordering. - # otherwise we might iterate through individual sets of - # "defaults", "primary key cols", etc. 
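For reviewers tracing the new visit_update() path introduced here, a short sketch of the basic statement shape it assembles, again assuming the 1.x public API with illustrative names:

    from sqlalchemy import table, column

    t = table('t', column('id'), column('x'))
    print(t.update().values(x=5).where(t.c.id == 7))
    # UPDATE t SET x=:x WHERE t.id = :id_1
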
- for c in stmt.table.columns: - if c.key in parameters: - value = parameters[c.key] - if sql._is_literal(value): - value = self._create_crud_bind_param(c, value, required=value is required) - else: - self.postfetch.append(c) - value = self.process(value.self_group()) - values.append((c, value)) - - elif self.isinsert: - if c.primary_key and \ - need_pks and \ - ( - implicit_returning or - not postfetch_lastrowid or - c is not stmt.table._autoincrement_column - ): - - if implicit_returning: - if c.default is not None: - if c.default.is_sequence: - proc = self.process(c.default) - if proc is not None: - values.append((c, proc)) - self.returning.append(c) - elif c.default.is_clause_element: - values.append((c, self.process(c.default.arg.self_group()))) - self.returning.append(c) - else: - values.append((c, self._create_crud_bind_param(c, None))) - self.prefetch.append(c) - else: - self.returning.append(c) - else: - if ( - c.default is not None and \ - ( - self.dialect.supports_sequences or - not c.default.is_sequence - ) - ) or self.dialect.preexecute_autoincrement_sequences: + text += ' SET ' + include_table = extra_froms and \ + self.render_table_with_column_in_update_from + text += ', '.join( + c[0]._compiler_dispatch(self, + include_table=include_table) + + '=' + c[1] for c in crud_params + ) - values.append((c, self._create_crud_bind_param(c, None))) - self.prefetch.append(c) - - elif c.default is not None: - if c.default.is_sequence: - proc = self.process(c.default) - if proc is not None: - values.append((c, proc)) - if not c.primary_key: - self.postfetch.append(c) - elif c.default.is_clause_element: - values.append((c, self.process(c.default.arg.self_group()))) - - if not c.primary_key: - # dont add primary key column to postfetch - self.postfetch.append(c) - else: - values.append((c, self._create_crud_bind_param(c, None))) - self.prefetch.append(c) - elif c.server_default is not None: - if not c.primary_key: - self.postfetch.append(c) - - elif self.isupdate: - if c.onupdate is not None and not c.onupdate.is_sequence: - if c.onupdate.is_clause_element: - values.append((c, self.process(c.onupdate.arg.self_group()))) - self.postfetch.append(c) - else: - values.append((c, self._create_crud_bind_param(c, None))) - self.prefetch.append(c) - elif c.server_onupdate is not None: - self.postfetch.append(c) - return values - - def visit_delete(self, delete_stmt): - self.stack.append({'from': set([delete_stmt.table])}) - self.isdelete = True - - text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) - - if delete_stmt._returning: - self.returning = delete_stmt._returning + if self.returning or update_stmt._returning: if self.returning_precedes_values: - text += " " + self.returning_clause(delete_stmt, delete_stmt._returning) - - if delete_stmt._whereclause is not None: - text += " WHERE " + self.process(delete_stmt._whereclause) + text += " " + self.returning_clause( + update_stmt, self.returning or update_stmt._returning) + + if extra_froms: + extra_from_text = self.update_from_clause( + update_stmt, + update_stmt.table, + extra_froms, + dialect_hints, **kw) + if extra_from_text: + text += " " + extra_from_text + + if update_stmt._whereclause is not None: + t = self.process(update_stmt._whereclause, **kw) + if t: + text += " WHERE " + t + + limit_clause = self.update_limit_clause(update_stmt) + if limit_clause: + text += " " + limit_clause + + if (self.returning or update_stmt._returning) and \ + not self.returning_precedes_values: + text += " " + self.returning_clause( + 
update_stmt, self.returning or update_stmt._returning)
+
+        if self.ctes and toplevel:
+            text = self._render_cte_clause() + text

-        if self.returning and not self.returning_precedes_values:
-            text += " " + self.returning_clause(delete_stmt, delete_stmt._returning)
-        self.stack.pop(-1)
-        return text
+        if asfrom:
+            return "(" + text + ")"
+        else:
+            return text
+
+    @util.memoized_property
+    def _key_getters_for_crud_column(self):
+        return crud._key_getters_for_crud_column(self, self.statement)
+
+    def visit_delete(self, delete_stmt, asfrom=False, **kw):
+        toplevel = not self.stack
+
+        self.stack.append({'correlate_froms': set([delete_stmt.table]),
+                           "asfrom_froms": set([delete_stmt.table]),
+                           "selectable": delete_stmt})
+
+        crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw)
+
+        text = "DELETE "
+
+        if delete_stmt._prefixes:
+            text += self._generate_prefixes(delete_stmt,
+                                            delete_stmt._prefixes, **kw)
+
+        text += "FROM "
+        table_text = delete_stmt.table._compiler_dispatch(
+            self, asfrom=True, iscrud=True)
+
+        if delete_stmt._hints:
+            dialect_hints, table_text = self._setup_crud_hints(
+                delete_stmt, table_text)
+
+        text += table_text
+
+        if delete_stmt._returning:
+            if self.returning_precedes_values:
+                text += " " + self.returning_clause(
+                    delete_stmt, delete_stmt._returning)
+
+        if delete_stmt._whereclause is not None:
+            t = delete_stmt._whereclause._compiler_dispatch(self, **kw)
+            if t:
+                text += " WHERE " + t
+
+        if delete_stmt._returning and not self.returning_precedes_values:
+            text += " " + self.returning_clause(
+                delete_stmt, delete_stmt._returning)
+
+        if self.ctes and toplevel:
+            text = self._render_cte_clause() + text
+
+        self.stack.pop(-1)
+
+        if asfrom:
+            return "(" + text + ")"
+        else:
+            return text

     def visit_savepoint(self, savepoint_stmt):
         return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)

     def visit_rollback_to_savepoint(self, savepoint_stmt):
-        return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
+        return "ROLLBACK TO SAVEPOINT %s" % \
+            self.preparer.format_savepoint(savepoint_stmt)

     def visit_release_savepoint(self, savepoint_stmt):
-        return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
+        return "RELEASE SAVEPOINT %s" % \
+            self.preparer.format_savepoint(savepoint_stmt)


-class DDLCompiler(engine.Compiled):
-
+class StrSQLCompiler(SQLCompiler):
+    """A compiler subclass with a few non-standard SQL features allowed.
+
+    Used for stringification of SQL statements when a real dialect is not
+    available.
+ + """ + + def _fallback_column_name(self, column): + return "" + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s[%s]" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw) + ) + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self.visit_getitem_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self.visit_getitem_binary(binary, operator, **kw) + + def returning_clause(self, stmt, returning_cols): + columns = [ + self._label_select_column(None, c, True, False, {}) + for c in elements._select_iterables(returning_cols) + ] + + return 'RETURNING ' + ', '.join(columns) + + +class DDLCompiler(Compiled): + @util.memoized_property def sql_compiler(self): - return self.dialect.statement_compiler(self.dialect, self.statement) - - @property - def preparer(self): - return self.dialect.identifier_preparer + return self.dialect.statement_compiler(self.dialect, None) + + @util.memoized_property + def type_compiler(self): + return self.dialect.type_compiler def construct_params(self, params=None): return None - + def visit_ddl(self, ddl, **kwargs): # table events can substitute table and schema name context = ddl.context if isinstance(ddl.target, schema.Table): context = context.copy() - preparer = self.dialect.identifier_preparer + preparer = self.preparer path = preparer.format_table_seq(ddl.target) if len(path) == 1: table, sch = path[0], '' @@ -1106,113 +2288,230 @@ class DDLCompiler(engine.Compiled): context.setdefault('table', table) context.setdefault('schema', sch) context.setdefault('fullname', preparer.format_table(ddl.target)) - - return ddl.statement % context + + return self.sql_compiler.post_process_text(ddl.statement % context) + + def visit_create_schema(self, create): + schema = self.preparer.format_schema(create.element) + return "CREATE SCHEMA " + schema + + def visit_drop_schema(self, drop): + schema = self.preparer.format_schema(drop.element) + text = "DROP SCHEMA " + schema + if drop.cascade: + text += " CASCADE" + return text def visit_create_table(self, create): table = create.element - preparer = self.dialect.identifier_preparer + preparer = self.preparer + + text = "\nCREATE " + if table._prefixes: + text += " ".join(table._prefixes) + " " + text += "TABLE " + preparer.format_table(table) + " " + + create_table_suffix = self.create_table_suffix(table) + if create_table_suffix: + text += create_table_suffix + " " + + text += "(" - text = "\n" + " ".join(['CREATE'] + \ - table._prefixes + \ - ['TABLE', - preparer.format_table(table), - "("]) separator = "\n" # if only one primary key, specify it along with the column first_pk = False - for column in table.columns: - text += separator - separator = ", \n" - text += "\t" + self.get_column_specification( - column, - first_pk=column.primary_key and not first_pk - ) - if column.primary_key: - first_pk = True - const = " ".join(self.process(constraint) for constraint in column.constraints) - if const: - text += " " + const + for create_column in create.columns: + column = create_column.element + try: + processed = self.process(create_column, + first_pk=column.primary_key + and not first_pk) + if processed is not None: + text += separator + separator = ", \n" + text += "\t" + processed + if column.primary_key: + first_pk = True + except exc.CompileError as ce: + util.raise_from_cause( + exc.CompileError( + util.u("(in table '%s', column '%s'): %s") % + (table.description, column.name, ce.args[0]) + )) - 
const = self.create_table_constraints(table) + const = self.create_table_constraints( + table, _include_foreign_key_constraints= # noqa + create.include_foreign_key_constraints) if const: - text += ", \n\t" + const + text += separator + "\t" + const text += "\n)%s\n\n" % self.post_create_table(table) return text - def create_table_constraints(self, table): - + def visit_create_column(self, create, first_pk=False): + column = create.element + + if column.system: + return None + + text = self.get_column_specification( + column, + first_pk=first_pk + ) + const = " ".join(self.process(constraint) + for constraint in column.constraints) + if const: + text += " " + const + + return text + + def create_table_constraints( + self, table, + _include_foreign_key_constraints=None): + # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) constraints = [] if table.primary_key: constraints.append(table.primary_key) - - constraints.extend([c for c in table.constraints if c is not table.primary_key]) - - return ", \n\t".join(p for p in - (self.process(constraint) for constraint in constraints - if ( - constraint._create_rule is None or - constraint._create_rule(self)) - and ( - not self.dialect.supports_alter or - not getattr(constraint, 'use_alter', False) - )) if p is not None - ) - + + all_fkcs = table.foreign_key_constraints + if _include_foreign_key_constraints is not None: + omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints) + else: + omit_fkcs = set() + + constraints.extend([c for c in table._sorted_constraints + if c is not table.primary_key and + c not in omit_fkcs]) + + return ", \n\t".join( + p for p in + (self.process(constraint) + for constraint in constraints + if ( + constraint._create_rule is None or + constraint._create_rule(self)) + and ( + not self.dialect.supports_alter or + not getattr(constraint, 'use_alter', False) + )) if p is not None + ) + def visit_drop_table(self, drop): return "\nDROP TABLE " + self.preparer.format_table(drop.element) - - def visit_create_index(self, create): + + def visit_drop_view(self, drop): + return "\nDROP VIEW " + self.preparer.format_table(drop.element) + + def _verify_index_table(self, index): + if index.table is None: + raise exc.CompileError("Index '%s' is not associated " + "with any table." 
% index.name) + + def visit_create_index(self, create, include_schema=False, + include_table_schema=True): index = create.element + self._verify_index_table(index) preparer = self.preparer text = "CREATE " if index.unique: text += "UNIQUE " text += "INDEX %s ON %s (%s)" \ - % (preparer.quote(self._validate_identifier(index.name, True), index.quote), - preparer.format_table(index.table), - ', '.join(preparer.quote(c.name, c.quote) - for c in index.columns)) + % ( + self._prepared_index_name(index, + include_schema=include_schema), + preparer.format_table(index.table, + use_schema=include_table_schema), + ', '.join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True) for + expr in index.expressions) + ) return text def visit_drop_index(self, drop): index = drop.element - return "\nDROP INDEX " + \ - self.preparer.quote(self._validate_identifier(index.name, False), index.quote) + return "\nDROP INDEX " + self._prepared_index_name( + index, include_schema=True) + + def _prepared_index_name(self, index, include_schema=False): + if index.table is not None: + effective_schema = self.preparer.schema_for_object(index.table) + else: + effective_schema = None + if include_schema and effective_schema: + schema_name = self.preparer.quote_schema(effective_schema) + else: + schema_name = None + + ident = index.name + if isinstance(ident, elements._truncated_label): + max_ = self.dialect.max_index_name_length or \ + self.dialect.max_identifier_length + if len(ident) > max_: + ident = ident[0:max_ - 8] + \ + "_" + util.md5_hex(ident)[-4:] + else: + self.dialect.validate_identifier(ident) + + index_name = self.preparer.quote(ident) + + if schema_name: + index_name = schema_name + "." + index_name + return index_name def visit_add_constraint(self, create): - preparer = self.preparer return "ALTER TABLE %s ADD %s" % ( self.preparer.format_table(create.element.table), self.process(create.element) ) def visit_create_sequence(self, create): - text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) + text = "CREATE SEQUENCE %s" % \ + self.preparer.format_sequence(create.element) if create.element.increment is not None: text += " INCREMENT BY %d" % create.element.increment if create.element.start is not None: text += " START WITH %d" % create.element.start + if create.element.minvalue is not None: + text += " MINVALUE %d" % create.element.minvalue + if create.element.maxvalue is not None: + text += " MAXVALUE %d" % create.element.maxvalue + if create.element.nominvalue is not None: + text += " NO MINVALUE" + if create.element.nomaxvalue is not None: + text += " NO MAXVALUE" + if create.element.cycle is not None: + text += " CYCLE" return text - + def visit_drop_sequence(self, drop): - return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + return "DROP SEQUENCE %s" % \ + self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): - preparer = self.preparer + constraint = drop.element + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + else: + formatted_name = None + + if formatted_name is None: + raise exc.CompileError( + "Can't emit DROP CONSTRAINT for constraint %r; " + "it has no name" % drop.element) return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( self.preparer.format_table(drop.element.table), - self.preparer.format_constraint(drop.element), + formatted_name, drop.cascade and " CASCADE" or "" ) - + def get_column_specification(self, column, **kwargs): colspec = 
self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process( + column.type, type_expression=column) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -1221,42 +2520,42 @@ class DDLCompiler(engine.Compiled): colspec += " NOT NULL" return colspec + def create_table_suffix(self, table): + return '' + def post_create_table(self, table): return '' - def _validate_identifier(self, ident, truncate): - if truncate: - if len(ident) > self.dialect.max_identifier_length: - counter = getattr(self, 'counter', 0) - self.counter = counter + 1 - return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] - else: - return ident - else: - self.dialect.validate_identifier(ident) - return ident - def get_column_default_string(self, column): if isinstance(column.server_default, schema.DefaultClause): - if isinstance(column.server_default.arg, basestring): - return "'%s'" % column.server_default.arg + if isinstance(column.server_default.arg, util.string_types): + return self.sql_compiler.render_literal_value( + column.server_default.arg, sqltypes.STRINGTYPE) else: - return self.sql_compiler.process(column.server_default.arg) + return self.sql_compiler.process( + column.server_default.arg, literal_binds=True) else: return None def visit_check_constraint(self, constraint): text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) - sqltext = sql_util.expression_as_ddl(constraint.sqltext) - text += "CHECK (%s)" % self.sql_compiler.process(sqltext) + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, + include_table=False, + literal_binds=True) text += self.define_constraint_deferrability(constraint) return text def visit_column_check_constraint(self, constraint): - text = " CHECK (%s)" % constraint.sqltext + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "CHECK (%s)" % constraint.sqltext text += self.define_constraint_deferrability(constraint) return text @@ -1265,36 +2564,53 @@ class DDLCompiler(engine.Compiled): return '' text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name text += "PRIMARY KEY " - text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) - for c in constraint) + text += "(%s)" % ', '.join(self.preparer.quote(c.name) + for c in (constraint.columns_autoinc_first + if constraint._implicit_generated + else constraint.columns)) text += self.define_constraint_deferrability(constraint) return text def visit_foreign_key_constraint(self, constraint): - preparer = self.dialect.identifier_preparer + preparer = self.preparer text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - preparer.format_constraint(constraint) - remote_table = list(constraint._elements.values())[0].column.table + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + remote_table = 
list(constraint.elements)[0].column.table text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join(preparer.quote(f.parent.name, f.parent.quote) - for f in constraint._elements.values()), - preparer.format_table(remote_table), - ', '.join(preparer.quote(f.column.name, f.column.quote) - for f in constraint._elements.values()) + ', '.join(preparer.quote(f.parent.name) + for f in constraint.elements), + self.define_constraint_remote_table( + constraint, remote_table, preparer), + ', '.join(preparer.quote(f.column.name) + for f in constraint.elements) ) + text += self.define_constraint_match(constraint) text += self.define_constraint_cascades(constraint) text += self.define_constraint_deferrability(constraint) return text + def define_constraint_remote_table(self, constraint, table, preparer): + """Format the remote table clause of a CREATE CONSTRAINT clause.""" + + return preparer.format_table(table) + def visit_unique_constraint(self, constraint): + if len(constraint) == 0: + return '' text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) - text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)) + formatted_name = self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % formatted_name + text += "UNIQUE (%s)" % ( + ', '.join(self.preparer.quote(c.name) + for c in constraint)) text += self.define_constraint_deferrability(constraint) return text @@ -1305,7 +2621,7 @@ class DDLCompiler(engine.Compiled): if constraint.onupdate is not None: text += " ON UPDATE %s" % constraint.onupdate return text - + def define_constraint_deferrability(self, constraint): text = "" if constraint.deferrable is not None: @@ -1316,132 +2632,180 @@ class DDLCompiler(engine.Compiled): if constraint.initially is not None: text += " INITIALLY %s" % constraint.initially return text - - -class GenericTypeCompiler(engine.TypeCompiler): - def visit_CHAR(self, type_): - return "CHAR" + (type_.length and "(%d)" % type_.length or "") - def visit_NCHAR(self, type_): - return "NCHAR" + (type_.length and "(%d)" % type_.length or "") - - def visit_FLOAT(self, type_): + def define_constraint_match(self, constraint): + text = "" + if constraint.match is not None: + text += " MATCH %s" % constraint.match + return text + + +class GenericTypeCompiler(TypeCompiler): + + def visit_FLOAT(self, type_, **kw): return "FLOAT" - def visit_NUMERIC(self, type_): + def visit_REAL(self, type_, **kw): + return "REAL" + + def visit_NUMERIC(self, type_, **kw): if type_.precision is None: return "NUMERIC" elif type_.scale is None: - return "NUMERIC(%(precision)s)" % {'precision': type_.precision} + return "NUMERIC(%(precision)s)" % \ + {'precision': type_.precision} else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale} + return "NUMERIC(%(precision)s, %(scale)s)" % \ + {'precision': type_.precision, + 'scale': type_.scale} - def visit_DECIMAL(self, type_): - return "DECIMAL" - - def visit_INTEGER(self, type_): + def visit_DECIMAL(self, type_, **kw): + if type_.precision is None: + return "DECIMAL" + elif type_.scale is None: + return "DECIMAL(%(precision)s)" % \ + {'precision': type_.precision} + else: + return "DECIMAL(%(precision)s, %(scale)s)" % \ + {'precision': type_.precision, + 'scale': type_.scale} + + def visit_INTEGER(self, type_, **kw): return "INTEGER" - def visit_SMALLINT(self, type_): + def visit_SMALLINT(self, type_, **kw): return "SMALLINT" - def visit_BIGINT(self, 
type_): + def visit_BIGINT(self, type_, **kw): return "BIGINT" - def visit_TIMESTAMP(self, type_): + def visit_TIMESTAMP(self, type_, **kw): return 'TIMESTAMP' - def visit_DATETIME(self, type_): + def visit_DATETIME(self, type_, **kw): return "DATETIME" - def visit_DATE(self, type_): + def visit_DATE(self, type_, **kw): return "DATE" - def visit_TIME(self, type_): + def visit_TIME(self, type_, **kw): return "TIME" - def visit_CLOB(self, type_): + def visit_CLOB(self, type_, **kw): return "CLOB" - def visit_NCLOB(self, type_): + def visit_NCLOB(self, type_, **kw): return "NCLOB" - def visit_VARCHAR(self, type_): - return "VARCHAR" + (type_.length and "(%d)" % type_.length or "") + def _render_string_type(self, type_, name): - def visit_NVARCHAR(self, type_): - return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "") + text = name + if type_.length: + text += "(%d)" % type_.length + if type_.collation: + text += ' COLLATE "%s"' % type_.collation + return text - def visit_BLOB(self, type_): + def visit_CHAR(self, type_, **kw): + return self._render_string_type(type_, "CHAR") + + def visit_NCHAR(self, type_, **kw): + return self._render_string_type(type_, "NCHAR") + + def visit_VARCHAR(self, type_, **kw): + return self._render_string_type(type_, "VARCHAR") + + def visit_NVARCHAR(self, type_, **kw): + return self._render_string_type(type_, "NVARCHAR") + + def visit_TEXT(self, type_, **kw): + return self._render_string_type(type_, "TEXT") + + def visit_BLOB(self, type_, **kw): return "BLOB" - def visit_BINARY(self, type_): + def visit_BINARY(self, type_, **kw): return "BINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_VARBINARY(self, type_): + def visit_VARBINARY(self, type_, **kw): return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - - def visit_BOOLEAN(self, type_): + + def visit_BOOLEAN(self, type_, **kw): return "BOOLEAN" - - def visit_TEXT(self, type_): - return "TEXT" - - def visit_large_binary(self, type_): - return self.visit_BLOB(type_) - - def visit_boolean(self, type_): - return self.visit_BOOLEAN(type_) - - def visit_time(self, type_): - return self.visit_TIME(type_) - - def visit_datetime(self, type_): - return self.visit_DATETIME(type_) - - def visit_date(self, type_): - return self.visit_DATE(type_) - def visit_big_integer(self, type_): - return self.visit_BIGINT(type_) - - def visit_small_integer(self, type_): - return self.visit_SMALLINT(type_) - - def visit_integer(self, type_): - return self.visit_INTEGER(type_) - - def visit_float(self, type_): - return self.visit_FLOAT(type_) - - def visit_numeric(self, type_): - return self.visit_NUMERIC(type_) - - def visit_string(self, type_): - return self.visit_VARCHAR(type_) - - def visit_unicode(self, type_): - return self.visit_VARCHAR(type_) + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) + + def visit_boolean(self, type_, **kw): + return self.visit_BOOLEAN(type_, **kw) + + def visit_time(self, type_, **kw): + return self.visit_TIME(type_, **kw) + + def visit_datetime(self, type_, **kw): + return self.visit_DATETIME(type_, **kw) + + def visit_date(self, type_, **kw): + return self.visit_DATE(type_, **kw) + + def visit_big_integer(self, type_, **kw): + return self.visit_BIGINT(type_, **kw) + + def visit_small_integer(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) + + def visit_integer(self, type_, **kw): + return self.visit_INTEGER(type_, **kw) + + def visit_real(self, type_, **kw): + return self.visit_REAL(type_, **kw) + + def 
visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) + + def visit_numeric(self, type_, **kw): + return self.visit_NUMERIC(type_, **kw) + + def visit_string(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_unicode(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) + + def visit_unicode_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) + + def visit_enum(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_null(self, type_, **kw): + raise exc.CompileError("Can't generate DDL for %r; " + "did you forget to specify a " + "type on this Column?" % type_) + + def visit_type_decorator(self, type_, **kw): + return self.process(type_.type_engine(self.dialect), **kw) + + def visit_user_defined(self, type_, **kw): + return type_.get_col_spec(**kw) + + +class StrSQLTypeCompiler(GenericTypeCompiler): + def __getattr__(self, key): + if key.startswith("visit_"): + return self._visit_unknown + else: + raise AttributeError(key) + + def _visit_unknown(self, type_, **kw): + return "%s" % type_.__class__.__name__ - def visit_text(self, type_): - return self.visit_TEXT(type_) - def visit_unicode_text(self, type_): - return self.visit_TEXT(type_) - - def visit_enum(self, type_): - return self.visit_VARCHAR(type_) - - def visit_null(self, type_): - raise NotImplementedError("Can't generate DDL for the null type") - - def visit_type_decorator(self, type_): - return self.process(type_.type_engine(self.dialect)) - - def visit_user_defined(self, type_): - return type_.get_col_spec() - class IdentifierPreparer(object): + """Handle quoting and case-folding of identifiers based on options.""" reserved_words = RESERVED_WORDS @@ -1450,15 +2814,18 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - def __init__(self, dialect, initial_quote='"', - final_quote=None, escape_quote='"', omit_schema=False): + schema_for_object = schema._schema_getter(None) + + def __init__(self, dialect, initial_quote='"', + final_quote=None, escape_quote='"', omit_schema=False): """Construct a new ``IdentifierPreparer`` object. initial_quote Character that begins a delimited identifier. final_quote - Character that ends a delimited identifier. Defaults to `initial_quote`. + Character that ends a delimited identifier. Defaults to + `initial_quote`. omit_schema Prevent prepending schema name. Useful for databases that do @@ -1472,7 +2839,13 @@ class IdentifierPreparer(object): self.escape_to_quote = self.escape_quote * 2 self.omit_schema = omit_schema self._strings = {} - + + def _with_schema_translate(self, schema_translate_map): + prep = self.__class__.__new__(self.__class__) + prep.__dict__.update(self.__dict__) + prep.schema_for_object = schema._schema_getter(schema_translate_map) + return prep + def _escape_identifier(self, value): """Escape an identifier. @@ -1498,25 +2871,37 @@ class IdentifierPreparer(object): quoting behavior. 
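The conditional quoting implemented by quote() below is driven by reserved words, illegal initial characters, and case sensitivity. A brief sketch of the observable behavior, assuming the 1.x module layout (the identifiers passed in are illustrative):

    from sqlalchemy.engine import default
    from sqlalchemy.sql.compiler import IdentifierPreparer

    preparer = IdentifierPreparer(default.DefaultDialect())
    print(preparer.quote('plain_name'))  # plain_name   (no quoting needed)
    print(preparer.quote('select'))      # "select"     (reserved word)
    print(preparer.quote('MixedCase'))   # "MixedCase"  (case-sensitive name)
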
""" - return self.initial_quote + self._escape_identifier(value) + self.final_quote + return self.initial_quote + \ + self._escape_identifier(value) + \ + self.final_quote def _requires_quotes(self, value): """Return True if the given identifier requires quoting.""" lc_value = value.lower() return (lc_value in self.reserved_words or value[0] in self.illegal_initial_characters - or not self.legal_characters.match(unicode(value)) + or not self.legal_characters.match(util.text_type(value)) or (lc_value != value)) - def quote_schema(self, schema, force): - """Quote a schema. + def quote_schema(self, schema, force=None): + """Conditionally quote a schema. + + Subclasses can override this to provide database-dependent + quoting behavior for schema names. + + the 'force' flag should be considered deprecated. - Subclasses should override this to provide database-dependent - quoting behavior. """ return self.quote(schema, force) - def quote(self, ident, force): + def quote(self, ident, force=None): + """Conditionally quote an identifier. + + the 'force' flag should be considered deprecated. + """ + + force = getattr(ident, "quote", None) + if force is None: if ident in self._strings: return self._strings[ident] @@ -1532,47 +2917,81 @@ class IdentifierPreparer(object): return ident def format_sequence(self, sequence, use_schema=True): - name = self.quote(sequence.name, sequence.quote) - if not self.omit_schema and use_schema and sequence.schema is not None: - name = self.quote_schema(sequence.schema, sequence.quote) + "." + name + name = self.quote(sequence.name) + + effective_schema = self.schema_for_object(sequence) + + if (not self.omit_schema and use_schema and + effective_schema is not None): + name = self.quote_schema(effective_schema) + "." + name return name def format_label(self, label, name=None): - return self.quote(name or label.name, label.quote) + return self.quote(name or label.name) def format_alias(self, alias, name=None): - return self.quote(name or alias.name, alias.quote) + return self.quote(name or alias.name) def format_savepoint(self, savepoint, name=None): - return self.quote(name or savepoint.ident, savepoint.quote) + # Running the savepoint name through quoting is unnecessary + # for all known dialects. This is here to support potential + # third party use cases + ident = name or savepoint.ident + if self._requires_quotes(ident): + ident = self.quote_identifier(ident) + return ident + + @util.dependencies("sqlalchemy.sql.naming") + def format_constraint(self, naming, constraint): + if isinstance(constraint.name, elements._defer_name): + name = naming._constraint_name_for_table( + constraint, constraint.table) + if name: + return self.quote(name) + elif isinstance(constraint.name, elements._defer_none_name): + return None + return self.quote(constraint.name) - def format_constraint(self, constraint): - return self.quote(constraint.name, constraint.quote) - def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" if name is None: name = table.name - result = self.quote(name, table.quote) - if not self.omit_schema and use_schema and getattr(table, "schema", None): - result = self.quote_schema(table.schema, table.quote_schema) + "." + result + result = self.quote(name) + + effective_schema = self.schema_for_object(table) + + if not self.omit_schema and use_schema \ + and effective_schema: + result = self.quote_schema(effective_schema) + "." 
+ result return result - def format_column(self, column, use_table=False, name=None, table_name=None): + def format_schema(self, name, quote=None): + """Prepare a quoted schema name.""" + + return self.quote(name, quote) + + def format_column(self, column, use_table=False, + name=None, table_name=None): """Prepare a quoted column name.""" if name is None: name = column.name if not getattr(column, 'is_literal', False): if use_table: - return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote) + return self.format_table( + column.table, use_schema=False, + name=table_name) + "." + self.quote(name) else: - return self.quote(name, column.quote) + return self.quote(name) else: - # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted + # literal textual elements get stuck into ColumnClause a lot, + # which shouldn't get quoted + if use_table: - return self.format_table(column.table, use_schema=False, name=table_name) + "." + name + return self.format_table( + column.table, use_schema=False, + name=table_name) + '.' + name else: return name @@ -1583,8 +3002,11 @@ class IdentifierPreparer(object): # ('database', 'owner', etc.) could override this and return # a longer sequence. - if not self.omit_schema and use_schema and getattr(table, 'schema', None): - return (self.quote_schema(table.schema, table.quote_schema), + effective_schema = self.schema_for_object(table) + + if not self.omit_schema and use_schema and \ + effective_schema: + return (self.quote_schema(effective_schema), self.format_table(table, use_schema=False)) else: return (self.format_table(table, use_schema=False), ) @@ -1592,18 +3014,18 @@ class IdentifierPreparer(object): @util.memoized_property def _r_identifiers(self): initial, final, escaped_final = \ - [re.escape(s) for s in - (self.initial_quote, self.final_quote, - self._escape_identifier(self.final_quote))] + [re.escape(s) for s in + (self.initial_quote, self.final_quote, + self._escape_identifier(self.final_quote))] r = re.compile( r'(?:' r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' r'|([^\.]+))(?=\.|$))+' % - { 'initial': initial, - 'final': final, - 'escaped': escaped_final }) + {'initial': initial, + 'final': final, + 'escaped': escaped_final}) return r - + def unformat_identifiers(self, identifiers): """Unpack 'schema.table.column'-like strings into components.""" diff --git a/sqlalchemy/sql/expression.py b/sqlalchemy/sql/expression.py index 3aaa06f..172bf4b 100644 --- a/sqlalchemy/sql/expression.py +++ b/sqlalchemy/sql/expression.py @@ -1,4258 +1,144 @@ -# expression.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# sql/expression.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Defines the base components of SQL expression trees. +"""Defines the public namespace for SQL expression constructs. -All components are derived from a common base class -:class:`ClauseElement`. Common behaviors are organized -based on class hierarchies, in some cases via mixins. - -All object construction from this package occurs via functions which -in some cases will construct composite :class:`ClauseElement` structures -together, and in other cases simply return a single :class:`ClauseElement` -constructed directly. 
The function interface affords a more "DSL-ish" -feel to constructing SQL expressions and also allows future class -reorganizations. - -Even though classes are not constructed directly from the outside, -most classes which have additional public methods are considered to be -public (i.e. have no leading underscore). Other classes which are -"semi-public" are marked with a single leading underscore; these -classes usually have few or no public methods and are less guaranteed -to stay the same in future releases. +Prior to version 0.9, this module contained all of "elements", "dml", +"default_comparator" and "selectable". The module was broken up +and most "factory" functions were moved to be grouped with their associated +class. """ -import itertools, re -from operator import attrgetter - -from sqlalchemy import util, exc #, types as sqltypes -from sqlalchemy.sql import operators -from sqlalchemy.sql.visitors import Visitable, cloned_traverse -import operator - -functions, schema, sql_util, sqltypes = None, None, None, None -DefaultDialect, ClauseAdapter, Annotated = None, None, None - __all__ = [ - 'Alias', 'ClauseElement', - 'ColumnCollection', 'ColumnElement', - 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', - 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', - 'between', 'bindparam', 'case', 'cast', 'column', 'delete', - 'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', - 'modifier', 'collate', - 'insert', 'intersect', 'intersect_all', 'join', 'label', 'literal', - 'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select', - 'subquery', 'table', 'text', 'tuple_', 'union', 'union_all', 'update', ] - -PARSE_AUTOCOMMIT = util._symbol('PARSE_AUTOCOMMIT') - -def desc(column): - """Return a descending ``ORDER BY`` clause element. - - e.g.:: - - order_by = [desc(table1.mycol)] - - """ - return _UnaryExpression(column, modifier=operators.desc_op) - -def asc(column): - """Return an ascending ``ORDER BY`` clause element. - - e.g.:: - - order_by = [asc(table1.mycol)] - - """ - return _UnaryExpression(column, modifier=operators.asc_op) - -def outerjoin(left, right, onclause=None): - """Return an ``OUTER JOIN`` clause element. - - The returned object is an instance of :class:`Join`. - - Similar functionality is also available via the :func:`outerjoin()` - method on any :class:`FromClause`. - - left - The left side of the join. - - right - The right side of the join. - - onclause - Optional criterion for the ``ON`` clause, is derived from - foreign key relationships established between left and right - otherwise. - - To chain joins together, use the :func:`join()` or :func:`outerjoin()` - methods on the resulting :class:`Join` object. - - """ - return Join(left, right, onclause, isouter=True) - -def join(left, right, onclause=None, isouter=False): - """Return a ``JOIN`` clause element (regular inner join). - - The returned object is an instance of :class:`Join`. - - Similar functionality is also available via the :func:`join()` method - on any :class:`FromClause`. - - left - The left side of the join. - - right - The right side of the join. - - onclause - Optional criterion for the ``ON`` clause, is derived from - foreign key relationships established between left and right - otherwise. - - To chain joins together, use the :func:`join()` or :func:`outerjoin()` - methods on the resulting :class:`Join` object. 
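Although these docstrings are being removed from this module, the behavior they describe remains public. A compact sketch of join() as documented above, with illustrative table names:

    from sqlalchemy import table, column, join, select

    person = table('person', column('id'), column('name'))
    addr = table('addr', column('id'), column('person_id'))

    j = join(person, addr, person.c.id == addr.c.person_id)
    print(select([person.c.name]).select_from(j))
    # SELECT person.name FROM person JOIN addr ON person.id = addr.person_id
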
- - """ - return Join(left, right, onclause, isouter) - -def select(columns=None, whereclause=None, from_obj=[], **kwargs): - """Returns a ``SELECT`` clause element. - - Similar functionality is also available via the :func:`select()` - method on any :class:`FromClause`. - - The returned object is an instance of :class:`Select`. - - All arguments which accept :class:`ClauseElement` arguments also accept - string arguments, which will be converted as appropriate into - either :func:`text()` or :func:`literal_column()` constructs. - - columns - A list of :class:`ClauseElement` objects, typically :class:`ColumnElement` - objects or subclasses, which will form the columns clause of the - resulting statement. For all members which are instances of - :class:`Selectable`, the individual :class:`ColumnElement` members of the - :class:`Selectable` will be added individually to the columns clause. - For example, specifying a :class:`~sqlalchemy.schema.Table` instance will result in all - the contained :class:`~sqlalchemy.schema.Column` objects within to be added to the - columns clause. - - This argument is not present on the form of :func:`select()` - available on :class:`~sqlalchemy.schema.Table`. - - whereclause - A :class:`ClauseElement` expression which will be used to form the - ``WHERE`` clause. - - from_obj - A list of :class:`ClauseElement` objects which will be added to the - ``FROM`` clause of the resulting statement. Note that "from" - objects are automatically located within the columns and - whereclause ClauseElements. Use this parameter to explicitly - specify "from" objects which are not automatically locatable. - This could include :class:`~sqlalchemy.schema.Table` objects that aren't otherwise - present, or :class:`Join` objects whose presence will supercede that - of the :class:`~sqlalchemy.schema.Table` objects already located in the other clauses. - - \**kwargs - Additional parameters include: - - autocommit - Deprecated. Use .execution_options(autocommit=) - to set the autocommit option. - - prefixes - a list of strings or :class:`ClauseElement` objects to include - directly after the SELECT keyword in the generated statement, - for dialect-specific query features. - - distinct=False - when ``True``, applies a ``DISTINCT`` qualifier to the columns - clause of the resulting statement. - - use_labels=False - when ``True``, the statement will be generated using labels - for each column in the columns clause, which qualify each - column with its parent table's (or aliases) name so that name - conflicts between columns in different tables don't occur. - The format of the label is _. The "c" - collection of the resulting :class:`Select` object will use these - names as well for targeting column members. - - for_update=False - when ``True``, applies ``FOR UPDATE`` to the end of the - resulting statement. Certain database dialects also support - alternate values for this parameter, for example mysql - supports "read" which translates to ``LOCK IN SHARE MODE``, - and oracle supports "nowait" which translates to ``FOR UPDATE - NOWAIT``. - - correlate=True - indicates that this :class:`Select` object should have its - contained :class:`FromClause` elements "correlated" to an enclosing - :class:`Select` object. This means that any :class:`ClauseElement` - instance within the "froms" collection of this :class:`Select` - which is also present in the "froms" collection of an - enclosing select will not be rendered in the ``FROM`` clause - of this select statement. 
- - group_by - a list of :class:`ClauseElement` objects which will comprise the - ``GROUP BY`` clause of the resulting select. - - having - a :class:`ClauseElement` that will comprise the ``HAVING`` clause - of the resulting select when ``GROUP BY`` is used. - - order_by - a scalar or list of :class:`ClauseElement` objects which will - comprise the ``ORDER BY`` clause of the resulting select. - - limit=None - a numerical value which usually compiles to a ``LIMIT`` - expression in the resulting select. Databases that don't - support ``LIMIT`` will attempt to provide similar - functionality. - - offset=None - a numeric value which usually compiles to an ``OFFSET`` - expression in the resulting select. Databases that don't - support ``OFFSET`` will attempt to provide similar - functionality. - - bind=None - an ``Engine`` or ``Connection`` instance to which the - resulting ``Select ` object will be bound. The ``Select`` - object will otherwise automatically bind to whatever - ``Connectable`` instances can be located within its contained - :class:`ClauseElement` members. - - """ - return Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) - -def subquery(alias, *args, **kwargs): - """Return an :class:`Alias` object derived - from a :class:`Select`. - - name - alias name - - \*args, \**kwargs - - all other arguments are delivered to the - :func:`select` function. - - """ - return Select(*args, **kwargs).alias(alias) - -def insert(table, values=None, inline=False, **kwargs): - """Return an :class:`Insert` clause element. - - Similar functionality is available via the :func:`insert()` method on - :class:`~sqlalchemy.schema.Table`. - - :param table: The table to be inserted into. - - :param values: A dictionary which specifies the column specifications of the - ``INSERT``, and is optional. If left as None, the column - specifications are determined from the bind parameters used - during the compile phase of the ``INSERT`` statement. If the - bind parameters also are None during the compile phase, then the - column specifications will be generated from the full list of - table columns. Note that the :meth:`~Insert.values()` generative method - may also be used for this. - - :param prefixes: A list of modifier keywords to be inserted between INSERT - and INTO. Alternatively, the :meth:`~Insert.prefix_with` generative method - may be used. - - :param inline: if True, SQL defaults will be compiled 'inline' into the - statement and not pre-executed. - - If both `values` and compile-time bind parameters are present, the - compile-time bind parameters override the information specified - within `values` on a per-key basis. - - The keys within `values` can be either :class:`~sqlalchemy.schema.Column` objects or their - string identifiers. Each key may reference one of: - - * a literal data value (i.e. string, number, etc.); - * a Column object; - * a SELECT statement. - - If a ``SELECT`` statement is specified which references this - ``INSERT`` statement's table, the statement will be correlated - against the ``INSERT`` statement. - - """ - return Insert(table, values, inline=inline, **kwargs) - -def update(table, whereclause=None, values=None, inline=False, **kwargs): - """Return an :class:`Update` clause element. - - Similar functionality is available via the :func:`update()` method on - :class:`~sqlalchemy.schema.Table`. - - :param table: The table to be updated. - - :param whereclause: A :class:`ClauseElement` describing the ``WHERE`` condition - of the ``UPDATE`` statement. 
Note that the :meth:`~Update.where()` - generative method may also be used for this. - - :param values: - A dictionary which specifies the ``SET`` conditions of the - ``UPDATE``, and is optional. If left as None, the ``SET`` - conditions are determined from the bind parameters used during - the compile phase of the ``UPDATE`` statement. If the bind - parameters also are None during the compile phase, then the - ``SET`` conditions will be generated from the full list of table - columns. Note that the :meth:`~Update.values()` generative method may - also be used for this. - - :param inline: - if True, SQL defaults will be compiled 'inline' into the statement - and not pre-executed. - - If both `values` and compile-time bind parameters are present, the - compile-time bind parameters override the information specified - within `values` on a per-key basis. - - The keys within `values` can be either :class:`~sqlalchemy.schema.Column` objects or their - string identifiers. Each key may reference one of: - - * a literal data value (i.e. string, number, etc.); - * a Column object; - * a SELECT statement. - - If a ``SELECT`` statement is specified which references this - ``UPDATE`` statement's table, the statement will be correlated - against the ``UPDATE`` statement. - - """ - return Update( - table, - whereclause=whereclause, - values=values, - inline=inline, - **kwargs) - -def delete(table, whereclause = None, **kwargs): - """Return a :class:`Delete` clause element. - - Similar functionality is available via the :func:`delete()` method on - :class:`~sqlalchemy.schema.Table`. - - :param table: The table to be updated. - - :param whereclause: A :class:`ClauseElement` describing the ``WHERE`` - condition of the ``UPDATE`` statement. Note that the :meth:`~Delete.where()` - generative method may be used instead. - - """ - return Delete(table, whereclause, **kwargs) - -def and_(*clauses): - """Join a list of clauses together using the ``AND`` operator. - - The ``&`` operator is also overloaded on all - :class:`_CompareMixin` subclasses to produce the - same result. - - """ - if len(clauses) == 1: - return clauses[0] - return BooleanClauseList(operator=operators.and_, *clauses) - -def or_(*clauses): - """Join a list of clauses together using the ``OR`` operator. - - The ``|`` operator is also overloaded on all - :class:`_CompareMixin` subclasses to produce the - same result. - - """ - if len(clauses) == 1: - return clauses[0] - return BooleanClauseList(operator=operators.or_, *clauses) - -def not_(clause): - """Return a negation of the given clause, i.e. ``NOT(clause)``. - - The ``~`` operator is also overloaded on all - :class:`_CompareMixin` subclasses to produce the - same result. - - """ - return operators.inv(_literal_as_binds(clause)) - -def distinct(expr): - """Return a ``DISTINCT`` clause.""" - expr = _literal_as_binds(expr) - return _UnaryExpression(expr, operator=operators.distinct_op, type_=expr.type) - -def between(ctest, cleft, cright): - """Return a ``BETWEEN`` predicate clause. - - Equivalent of SQL ``clausetest BETWEEN clauseleft AND clauseright``. - - The :func:`between()` method on all - :class:`_CompareMixin` subclasses provides - similar functionality. - - """ - ctest = _literal_as_binds(ctest) - return ctest.between(cleft, cright) - - -def case(whens, value=None, else_=None): - """Produce a ``CASE`` statement. - - whens - A sequence of pairs, or alternatively a dict, - to be translated into "WHEN / THEN" clauses. 
- - value - Optional for simple case statements, produces - a column expression as in "CASE WHEN ..." - - else\_ - Optional as well, for case defaults produces - the "ELSE" portion of the "CASE" statement. - - The expressions used for THEN and ELSE, - when specified as strings, will be interpreted - as bound values. To specify textual SQL expressions - for these, use the literal_column() or - text() construct. - - The expressions used for the WHEN criterion - may only be literal strings when "value" is - present, i.e. CASE table.somecol WHEN "x" THEN "y". - Otherwise, literal strings are not accepted - in this position, and either the text() - or literal() constructs must be used to - interpret raw string values. - - Usage examples:: - - case([(orderline.c.qty > 100, item.c.specialprice), - (orderline.c.qty > 10, item.c.bulkprice) - ], else_=item.c.regularprice) - case(value=emp.c.type, whens={ - 'engineer': emp.c.salary * 1.1, - 'manager': emp.c.salary * 3, - }) - - Using :func:`literal_column()`, to allow for databases that - do not support bind parameters in the ``then`` clause. The type - can be specified which determines the type of the :func:`case()` construct - overall:: - - case([(orderline.c.qty > 100, literal_column("'greaterthan100'", String)), - (orderline.c.qty > 10, literal_column("'greaterthan10'", String)) - ], else_=literal_column("'lethan10'", String)) - - """ - - return _Case(whens, value=value, else_=else_) - -def cast(clause, totype, **kwargs): - """Return a ``CAST`` function. - - Equivalent of SQL ``CAST(clause AS totype)``. - - Use with a :class:`~sqlalchemy.types.TypeEngine` subclass, i.e:: - - cast(table.c.unit_price * table.c.qty, Numeric(10,4)) - - or:: - - cast(table.c.timestamp, DATE) - - """ - return _Cast(clause, totype, **kwargs) - -def extract(field, expr): - """Return the clause ``extract(field FROM expr)``.""" - - return _Extract(field, expr) - -def collate(expression, collation): - """Return the clause ``expression COLLATE collation``.""" - - expr = _literal_as_binds(expression) - return _BinaryExpression( - expr, - _literal_as_text(collation), - operators.collate, type_=expr.type) - -def exists(*args, **kwargs): - """Return an ``EXISTS`` clause as applied to a :class:`Select` object. - - Calling styles are of the following forms:: - - # use on an existing select() - s = select([table.c.col1]).where(table.c.col2==5) - s = exists(s) - - # construct a select() at once - exists(['*'], **select_arguments).where(criterion) - - # columns argument is optional, generates "EXISTS (SELECT *)" - # by default. - exists().where(table.c.col2==5) - - """ - return _Exists(*args, **kwargs) - -def union(*selects, **kwargs): - """Return a ``UNION`` of multiple selectables. - - The returned object is an instance of - :class:`CompoundSelect`. - - A similar :func:`union()` method is available on all - :class:`FromClause` subclasses. - - \*selects - a list of :class:`Select` instances. - - \**kwargs - available keyword arguments are the same as those of - :func:`select`. - - """ - return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs) - -def union_all(*selects, **kwargs): - """Return a ``UNION ALL`` of multiple selectables. - - The returned object is an instance of - :class:`CompoundSelect`. - - A similar :func:`union_all()` method is available on all - :class:`FromClause` subclasses. - - \*selects - a list of :class:`Select` instances. - - \**kwargs - available keyword arguments are the same as those of - :func:`select`. 
- - """ - return CompoundSelect(CompoundSelect.UNION_ALL, *selects, **kwargs) - -def except_(*selects, **kwargs): - """Return an ``EXCEPT`` of multiple selectables. - - The returned object is an instance of - :class:`CompoundSelect`. - - \*selects - a list of :class:`Select` instances. - - \**kwargs - available keyword arguments are the same as those of - :func:`select`. - - """ - return CompoundSelect(CompoundSelect.EXCEPT, *selects, **kwargs) - -def except_all(*selects, **kwargs): - """Return an ``EXCEPT ALL`` of multiple selectables. - - The returned object is an instance of - :class:`CompoundSelect`. - - \*selects - a list of :class:`Select` instances. - - \**kwargs - available keyword arguments are the same as those of - :func:`select`. - - """ - return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects, **kwargs) - -def intersect(*selects, **kwargs): - """Return an ``INTERSECT`` of multiple selectables. - - The returned object is an instance of - :class:`CompoundSelect`. - - \*selects - a list of :class:`Select` instances. - - \**kwargs - available keyword arguments are the same as those of - :func:`select`. - - """ - return CompoundSelect(CompoundSelect.INTERSECT, *selects, **kwargs) - -def intersect_all(*selects, **kwargs): - """Return an ``INTERSECT ALL`` of multiple selectables. - - The returned object is an instance of - :class:`CompoundSelect`. - - \*selects - a list of :class:`Select` instances. - - \**kwargs - available keyword arguments are the same as those of - :func:`select`. - - """ - return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs) - -def alias(selectable, alias=None): - """Return an :class:`Alias` object. - - An :class:`Alias` represents any :class:`FromClause` - with an alternate name assigned within SQL, typically using the ``AS`` - clause when generated, e.g. ``SELECT * FROM table AS aliasname``. - - Similar functionality is available via the :func:`alias()` method - available on all :class:`FromClause` subclasses. - - selectable - any :class:`FromClause` subclass, such as a table, select - statement, etc.. - - alias - string name to be assigned as the alias. If ``None``, a - random name will be generated. - - """ - return Alias(selectable, alias=alias) - - -def literal(value, type_=None): - """Return a literal clause, bound to a bind parameter. - - Literal clauses are created automatically when non- :class:`ClauseElement` - objects (such as strings, ints, dates, etc.) are used in a comparison - operation with a :class:`_CompareMixin` - subclass, such as a :class:`~sqlalchemy.schema.Column` object. Use this function to force the - generation of a literal clause, which will be created as a - :class:`_BindParamClause` with a bound value. - - :param value: the value to be bound. Can be any Python object supported by - the underlying DB-API, or is translatable via the given type argument. - - :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` which - will provide bind-parameter translation for this literal. - - """ - return _BindParamClause(None, value, type_=type_, unique=True) - -def tuple_(*expr): - """Return a SQL tuple. - - Main usage is to produce a composite IN construct:: - - tuple_(table.c.col1, table.c.col2).in_( - [(1, 2), (5, 12), (10, 19)] - ) - - """ - return _Tuple(*expr) - -def label(name, obj): - """Return a :class:`_Label` object for the - given :class:`ColumnElement`. - - A label changes the name of an element in the columns clause of a - ``SELECT`` statement, typically via the ``AS`` SQL keyword. 
- - This functionality is more conveniently available via the - :func:`label()` method on :class:`ColumnElement`. - - name - label name - - obj - a :class:`ColumnElement`. - - """ - return _Label(name, obj) - -def column(text, type_=None): - """Return a textual column clause, as would be in the columns clause of a - ``SELECT`` statement. - - The object returned is an instance of - :class:`ColumnClause`, which represents the - "syntactical" portion of the schema-level - :class:`~sqlalchemy.schema.Column` object. - - text - the name of the column. Quoting rules will be applied to the - clause like any other column name. For textual column - constructs that are not to be quoted, use the - :func:`literal_column` function. - - type\_ - an optional :class:`~sqlalchemy.types.TypeEngine` object which will - provide result-set translation for this column. - - """ - return ColumnClause(text, type_=type_) - -def literal_column(text, type_=None): - """Return a textual column expression, as would be in the columns - clause of a ``SELECT`` statement. - - The object returned supports further expressions in the same way as any - other column object, including comparison, math and string operations. - The type\_ parameter is important to determine proper expression behavior - (such as, '+' means string concatenation or numerical addition based on - the type). - - text - the text of the expression; can be any SQL expression. Quoting rules - will not be applied. To specify a column-name expression which should - be subject to quoting rules, use the - :func:`column` function. - - type\_ - an optional :class:`~sqlalchemy.types.TypeEngine` object which will - provide result-set translation and additional expression semantics for - this column. If left as None the type will be NullType. - - """ - return ColumnClause(text, type_=type_, is_literal=True) - -def table(name, *columns): - """Return a :class:`TableClause` object. - - This is a primitive version of the :class:`~sqlalchemy.schema.Table` object, - which is a subclass of this object. - - """ - return TableClause(name, *columns) - -def bindparam(key, value=None, type_=None, unique=False, required=False): - """Create a bind parameter clause with the given key. - - value - a default value for this bind parameter. a bindparam with a - value is called a ``value-based bindparam``. - - type\_ - a sqlalchemy.types.TypeEngine object indicating the type of this - bind param, will invoke type-specific bind parameter processing - - unique - if True, bind params sharing the same name will have their - underlying ``key`` modified to a uniquely generated name. - mostly useful with value-based bind params. - - required - A value is required at execution time. - - """ - if isinstance(key, ColumnClause): - return _BindParamClause(key.name, value, type_=key.type, - unique=unique, required=required) - else: - return _BindParamClause(key, value, type_=type_, - unique=unique, required=required) - -def outparam(key, type_=None): - """Create an 'OUT' parameter for usage in functions (stored procedures), - for databases which support them. - - The ``outparam`` can be used like a regular function parameter. - The "output" value will be available from the - :class:`~sqlalchemy.engine.ResultProxy` object via its ``out_parameters`` - attribute, which returns a dictionary containing the values. - - """ - return _BindParamClause( - key, None, type_=type_, unique=False, isoutparam=True) - -def text(text, bind=None, *args, **kwargs): - """Create literal text to be inserted into a query. 
- - When constructing a query from a :func:`select()`, :func:`update()`, - :func:`insert()` or :func:`delete()`, using plain strings for argument - values will usually result in text objects being created - automatically. Use this function when creating textual clauses - outside of other :class:`ClauseElement` objects, or optionally wherever - plain text is to be used. - - text - the text of the SQL statement to be created. use ``:`` - to specify bind parameters; they will be compiled to their - engine-specific format. - - bind - an optional connection or engine to be used for this text query. - - autocommit=True - Deprecated. Use .execution_options(autocommit=) - to set the autocommit option. - - bindparams - a list of :func:`bindparam()` instances which can be used to define - the types and/or initial values for the bind parameters within - the textual statement; the keynames of the bindparams must match - those within the text of the statement. The types will be used - for pre-processing on bind values. - - typemap - a dictionary mapping the names of columns represented in the - ``SELECT`` clause of the textual statement to type objects, - which will be used to perform post-processing on columns within - the result set (for textual statements that produce result - sets). - - """ - return _TextClause(text, bind=bind, *args, **kwargs) - -def null(): - """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql - statement. - - """ - return _Null() - -class _FunctionGenerator(object): - """Generate :class:`Function` objects based on getattr calls.""" - - def __init__(self, **opts): - self.__names = [] - self.opts = opts - - def __getattr__(self, name): - # passthru __ attributes; fixes pydoc - if name.startswith('__'): - try: - return self.__dict__[name] - except KeyError: - raise AttributeError(name) - - elif name.endswith('_'): - name = name[0:-1] - f = _FunctionGenerator(**self.opts) - f.__names = list(self.__names) + [name] - return f - - def __call__(self, *c, **kwargs): - o = self.opts.copy() - o.update(kwargs) - if len(self.__names) == 1: - global functions - if functions is None: - from sqlalchemy.sql import functions - func = getattr(functions, self.__names[-1].lower(), None) - if func is not None: - return func(*c, **o) - - return Function( - self.__names[-1], packagenames=self.__names[0:-1], *c, **o) - -# "func" global - i.e. func.count() -func = _FunctionGenerator() - -# "modifier" global - i.e. modifier.distinct -# TODO: use UnaryExpression for this instead ? -modifier = _FunctionGenerator(group=False) - -class _generated_label(unicode): - """A unicode subclass used to identify dynamically generated names.""" - -def _escape_for_generated(x): - if isinstance(x, _generated_label): - return x - else: - return x.replace('%', '%%') - -def _clone(element): - return element._clone() - -def _expand_cloned(elements): - """expand the given set of ClauseElements to be the set of all 'cloned' - predecessors. - - """ - return itertools.chain(*[x._cloned_set for x in elements]) - -def _select_iterables(elements): - """expand tables into individual columns in the - given list of column expressions. - - """ - return itertools.chain(*[c._select_iterable for c in elements]) - -def _cloned_intersection(a, b): - """return the intersection of sets a and b, counting - any overlap between 'cloned' predecessors. - - The returned set is in terms of the enties present within 'a'. 
- - """ - all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set(elem for elem in a if all_overlap.intersection(elem._cloned_set)) - - -def _is_literal(element): - return not isinstance(element, Visitable) and \ - not hasattr(element, '__clause_element__') - -def _from_objects(*elements): - return itertools.chain(*[element._from_objects for element in elements]) - -def _labeled(element): - if not hasattr(element, 'name'): - return element.label(None) - else: - return element - -def _column_as_key(element): - if isinstance(element, basestring): - return element - if hasattr(element, '__clause_element__'): - element = element.__clause_element__() - return element.key - -def _literal_as_text(element): - if hasattr(element, '__clause_element__'): - return element.__clause_element__() - elif not isinstance(element, Visitable): - return _TextClause(unicode(element)) - else: - return element - -def _clause_element_as_expr(element): - if hasattr(element, '__clause_element__'): - return element.__clause_element__() - else: - return element - -def _literal_as_column(element): - if hasattr(element, '__clause_element__'): - return element.__clause_element__() - elif not isinstance(element, Visitable): - return literal_column(str(element)) - else: - return element - -def _literal_as_binds(element, name=None, type_=None): - if hasattr(element, '__clause_element__'): - return element.__clause_element__() - elif not isinstance(element, Visitable): - if element is None: - return null() - else: - return _BindParamClause(name, element, type_=type_, unique=True) - else: - return element - -def _type_from_args(args): - for a in args: - if not isinstance(a.type, sqltypes.NullType): - return a.type - else: - return sqltypes.NullType - -def _no_literals(element): - if hasattr(element, '__clause_element__'): - return element.__clause_element__() - elif not isinstance(element, Visitable): - raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' " - "function to indicate a SQL expression " - "literal, or 'literal()' to indicate a " - "bound value." % element) - else: - return element - -def _corresponding_column_or_error(fromclause, column, require_embedded=False): - c = fromclause.corresponding_column(column, - require_embedded=require_embedded) - if c is None: - raise exc.InvalidRequestError( - "Given column '%s', attached to table '%s', " - "failed to locate a corresponding column from table '%s'" - % - (column, - getattr(column, 'table', None),fromclause.description) - ) - return c - -@util.decorator -def _generative(fn, *args, **kw): - """Mark a method as generative.""" - - self = args[0]._generate() - fn(self, *args[1:], **kw) - return self - - -def is_column(col): - """True if ``col`` is an instance of :class:`ColumnElement`.""" - - return isinstance(col, ColumnElement) - - -class ClauseElement(Visitable): - """Base class for elements of a programmatically constructed SQL - expression. - - """ - __visit_name__ = 'clause' - - _annotations = {} - supports_execution = False - _from_objects = [] - _bind = None - - def _clone(self): - """Create a shallow copy of this ClauseElement. - - This method may be used by a generative API. Its also used as - part of the "deep" copy afforded by a traversal that combines - the _copy_internals() method. 
- - """ - c = self.__class__.__new__(self.__class__) - c.__dict__ = self.__dict__.copy() - c.__dict__.pop('_cloned_set', None) - - # this is a marker that helps to "equate" clauses to each other - # when a Select returns its list of FROM clauses. the cloning - # process leaves around a lot of remnants of the previous clause - # typically in the form of column expressions still attached to the - # old table. - c._is_clone_of = self - - return c - - @util.memoized_property - def _cloned_set(self): - """Return the set consisting all cloned anscestors of this - ClauseElement. - - Includes this ClauseElement. This accessor tends to be used for - FromClause objects to identify 'equivalent' FROM clauses, regardless - of transformative operations. - - """ - s = util.column_set() - f = self - while f is not None: - s.add(f) - f = getattr(f, '_is_clone_of', None) - return s - - def __getstate__(self): - d = self.__dict__.copy() - d.pop('_is_clone_of', None) - return d - - if util.jython: - def __hash__(self): - """Return a distinct hash code. - - ClauseElements may have special equality comparisons which - makes us rely on them having unique hash codes for use in - hash-based collections. Stock __hash__ doesn't guarantee - unique values on platforms with moving GCs. - """ - return id(self) - - def _annotate(self, values): - """return a copy of this ClauseElement with the given annotations - dictionary. - - """ - global Annotated - if Annotated is None: - from sqlalchemy.sql.util import Annotated - return Annotated(self, values) - - def _deannotate(self): - """return a copy of this ClauseElement with an empty annotations - dictionary. - - """ - return self._clone() - - def unique_params(self, *optionaldict, **kwargs): - """Return a copy with :func:`bindparam()` elments replaced. - - Same functionality as ``params()``, except adds `unique=True` - to affected bind parameters so that multiple statements can be - used. - - """ - return self._params(True, optionaldict, kwargs) - - def params(self, *optionaldict, **kwargs): - """Return a copy with :func:`bindparam()` elments replaced. - - Returns a copy of this ClauseElement with :func:`bindparam()` - elements replaced with values taken from the given dictionary:: - - >>> clause = column('x') + bindparam('foo') - >>> print clause.compile().params - {'foo':None} - >>> print clause.params({'foo':7}).compile().params - {'foo':7} - - """ - return self._params(False, optionaldict, kwargs) - - def _params(self, unique, optionaldict, kwargs): - if len(optionaldict) == 1: - kwargs.update(optionaldict[0]) - elif len(optionaldict) > 1: - raise exc.ArgumentError( - "params() takes zero or one positional dictionary argument") - - def visit_bindparam(bind): - if bind.key in kwargs: - bind.value = kwargs[bind.key] - if unique: - bind._convert_to_unique() - return cloned_traverse(self, {}, {'bindparam':visit_bindparam}) - - def compare(self, other, **kw): - """Compare this ClauseElement to the given ClauseElement. - - Subclasses should override the default behavior, which is a - straight identity comparison. - - \**kw are arguments consumed by subclass compare() methods and - may be used to modify the criteria for comparison. - (see :class:`ColumnElement`) - - """ - return self is other - - def _copy_internals(self, clone=_clone): - """Reassign internal elements to be clones of themselves. - - Called during a copy-and-traverse operation on newly - shallow-copied elements to create a deep copy. 
- - """ - pass - - def get_children(self, **kwargs): - """Return immediate child elements of this :class:`ClauseElement`. - - This is used for visit traversal. - - \**kwargs may contain flags that change the collection that is - returned, for example to return a subset of items in order to - cut down on larger traversals, or to return child items from a - different context (such as schema-level collections instead of - clause-level). - - """ - return [] - - def self_group(self, against=None): - return self - - # TODO: remove .bind as a method from the root ClauseElement. - # we should only be deriving binds from FromClause elements - # and certain SchemaItem subclasses. - # the "search_for_bind" functionality can still be used by - # execute(), however. - @property - def bind(self): - """Returns the Engine or Connection to which this ClauseElement is - bound, or None if none found. - - """ - if self._bind is not None: - return self._bind - - for f in _from_objects(self): - if f is self: - continue - engine = f.bind - if engine is not None: - return engine - else: - return None - - def execute(self, *multiparams, **params): - """Compile and execute this :class:`ClauseElement`.""" - - e = self.bind - if e is None: - label = getattr(self, 'description', self.__class__.__name__) - msg = ('This %s is not bound and does not support direct ' - 'execution. Supply this statement to a Connection or ' - 'Engine for execution. Or, assign a bind to the statement ' - 'or the Metadata of its underlying tables to enable ' - 'implicit execution via this method.' % label) - raise exc.UnboundExecutionError(msg) - return e._execute_clauseelement(self, multiparams, params) - - def scalar(self, *multiparams, **params): - """Compile and execute this :class:`ClauseElement`, returning the result's - scalar representation. - - """ - return self.execute(*multiparams, **params).scalar() - - def compile(self, bind=None, dialect=None, **kw): - """Compile this SQL expression. - - The return value is a :class:`~sqlalchemy.engine.Compiled` object. - Calling ``str()`` or ``unicode()`` on the returned value will yield a - string representation of the result. The - :class:`~sqlalchemy.engine.Compiled` object also can return a - dictionary of bind parameter names and values - using the ``params`` accessor. - - :param bind: An ``Engine`` or ``Connection`` from which a - ``Compiled`` will be acquired. This argument takes precedence over - this :class:`ClauseElement`'s bound engine, if any. - - :param column_keys: Used for INSERT and UPDATE statements, a list of - column names which should be present in the VALUES clause of the - compiled statement. If ``None``, all columns from the target table - object are rendered. - - :param dialect: A ``Dialect`` instance frmo which a ``Compiled`` - will be acquired. This argument takes precedence over the `bind` - argument as well as this :class:`ClauseElement`'s bound engine, if any. - - :param inline: Used for INSERT statements, for a dialect which does - not support inline retrieval of newly generated primary key - columns, will force the expression used to create the new primary - key value to be rendered inline within the INSERT statement's - VALUES clause. This typically refers to Sequence execution but may - also refer to any server-side default generation function - associated with a primary key `Column`. 
- - """ - - if not dialect: - if bind: - dialect = bind.dialect - elif self.bind: - dialect = self.bind.dialect - bind = self.bind - else: - global DefaultDialect - if DefaultDialect is None: - from sqlalchemy.engine.default import DefaultDialect - dialect = DefaultDialect() - compiler = self._compiler(dialect, bind=bind, **kw) - compiler.compile() - return compiler - - def _compiler(self, dialect, **kw): - """Return a compiler appropriate for this ClauseElement, given a Dialect.""" - - return dialect.statement_compiler(dialect, self, **kw) - - def __str__(self): - # Py3K - #return unicode(self.compile()) - # Py2K - return unicode(self.compile()).encode('ascii', 'backslashreplace') - # end Py2K - - def __and__(self, other): - return and_(self, other) - - def __or__(self, other): - return or_(self, other) - - def __invert__(self): - return self._negate() - - def __nonzero__(self): - raise TypeError("Boolean value of this clause is not defined") - - def _negate(self): - if hasattr(self, 'negation_clause'): - return self.negation_clause - else: - return _UnaryExpression( - self.self_group(against=operators.inv), - operator=operators.inv, - negate=None) - - def __repr__(self): - friendly = getattr(self, 'description', None) - if friendly is None: - return object.__repr__(self) - else: - return '<%s.%s at 0x%x; %s>' % ( - self.__module__, self.__class__.__name__, id(self), friendly) - - -class _Immutable(object): - """mark a ClauseElement as 'immutable' when expressions are cloned.""" - - def unique_params(self, *optionaldict, **kwargs): - raise NotImplementedError("Immutable objects do not support copying") - - def params(self, *optionaldict, **kwargs): - raise NotImplementedError("Immutable objects do not support copying") - - def _clone(self): - return self - -class Operators(object): - def __and__(self, other): - return self.operate(operators.and_, other) - - def __or__(self, other): - return self.operate(operators.or_, other) - - def __invert__(self): - return self.operate(operators.inv) - - def op(self, opstring): - def op(b): - return self.operate(operators.op, opstring, b) - return op - - def operate(self, op, *other, **kwargs): - raise NotImplementedError(str(op)) - - def reverse_operate(self, op, other, **kwargs): - raise NotImplementedError(str(op)) - -class ColumnOperators(Operators): - """Defines comparison and math operations.""" - - timetuple = None - """Hack, allows datetime objects to be compared on the LHS.""" - - def __lt__(self, other): - return self.operate(operators.lt, other) - - def __le__(self, other): - return self.operate(operators.le, other) - - __hash__ = Operators.__hash__ - - def __eq__(self, other): - return self.operate(operators.eq, other) - - def __ne__(self, other): - return self.operate(operators.ne, other) - - def __gt__(self, other): - return self.operate(operators.gt, other) - - def __ge__(self, other): - return self.operate(operators.ge, other) - - def __neg__(self): - return self.operate(operators.neg) - - def concat(self, other): - return self.operate(operators.concat_op, other) - - def like(self, other, escape=None): - return self.operate(operators.like_op, other, escape=escape) - - def ilike(self, other, escape=None): - return self.operate(operators.ilike_op, other, escape=escape) - - def in_(self, other): - return self.operate(operators.in_op, other) - - def startswith(self, other, **kwargs): - return self.operate(operators.startswith_op, other, **kwargs) - - def endswith(self, other, **kwargs): - return self.operate(operators.endswith_op, other, 
**kwargs) - - def contains(self, other, **kwargs): - return self.operate(operators.contains_op, other, **kwargs) - - def match(self, other, **kwargs): - return self.operate(operators.match_op, other, **kwargs) - - def desc(self): - return self.operate(operators.desc_op) - - def asc(self): - return self.operate(operators.asc_op) - - def collate(self, collation): - return self.operate(operators.collate, collation) - - def __radd__(self, other): - return self.reverse_operate(operators.add, other) - - def __rsub__(self, other): - return self.reverse_operate(operators.sub, other) - - def __rmul__(self, other): - return self.reverse_operate(operators.mul, other) - - def __rdiv__(self, other): - return self.reverse_operate(operators.div, other) - - def between(self, cleft, cright): - return self.operate(operators.between_op, cleft, cright) - - def distinct(self): - return self.operate(operators.distinct_op) - - def __add__(self, other): - return self.operate(operators.add, other) - - def __sub__(self, other): - return self.operate(operators.sub, other) - - def __mul__(self, other): - return self.operate(operators.mul, other) - - def __div__(self, other): - return self.operate(operators.div, other) - - def __mod__(self, other): - return self.operate(operators.mod, other) - - def __truediv__(self, other): - return self.operate(operators.truediv, other) - - def __rtruediv__(self, other): - return self.reverse_operate(operators.truediv, other) - -class _CompareMixin(ColumnOperators): - """Defines comparison and math operations for :class:`ClauseElement` instances.""" - - def __compare(self, op, obj, negate=None, reverse=False, **kwargs): - if obj is None or isinstance(obj, _Null): - if op == operators.eq: - return _BinaryExpression(self, null(), operators.is_, negate=operators.isnot) - elif op == operators.ne: - return _BinaryExpression(self, null(), operators.isnot, negate=operators.is_) - else: - raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL") - else: - obj = self._check_literal(op, obj) - - if reverse: - return _BinaryExpression(obj, - self, - op, - type_=sqltypes.BOOLEANTYPE, - negate=negate, modifiers=kwargs) - else: - return _BinaryExpression(self, - obj, - op, - type_=sqltypes.BOOLEANTYPE, - negate=negate, modifiers=kwargs) - - def __operate(self, op, obj, reverse=False): - obj = self._check_literal(op, obj) - - if reverse: - left, right = obj, self - else: - left, right = self, obj - - if left.type is None: - op, result_type = sqltypes.NULLTYPE._adapt_expression(op, right.type) - elif right.type is None: - op, result_type = left.type._adapt_expression(op, sqltypes.NULLTYPE) - else: - op, result_type = left.type._adapt_expression(op, right.type) - - return _BinaryExpression(left, right, op, type_=result_type) - - - # a mapping of operators with the method they use, along with their negated - # operator for comparison operators - operators = { - operators.add : (__operate,), - operators.mul : (__operate,), - operators.sub : (__operate,), - # Py2K - operators.div : (__operate,), - # end Py2K - operators.mod : (__operate,), - operators.truediv : (__operate,), - operators.lt : (__compare, operators.ge), - operators.le : (__compare, operators.gt), - operators.ne : (__compare, operators.eq), - operators.gt : (__compare, operators.le), - operators.ge : (__compare, operators.lt), - operators.eq : (__compare, operators.ne), - operators.like_op : (__compare, operators.notlike_op), - operators.ilike_op : (__compare, operators.notilike_op), - } - - def operate(self, op, *other, 
**kwargs): - o = _CompareMixin.operators[op] - return o[0](self, op, other[0], *o[1:], **kwargs) - - def reverse_operate(self, op, other, **kwargs): - o = _CompareMixin.operators[op] - return o[0](self, op, other, reverse=True, *o[1:], **kwargs) - - def in_(self, other): - return self._in_impl(operators.in_op, operators.notin_op, other) - - def _in_impl(self, op, negate_op, seq_or_selectable): - seq_or_selectable = _clause_element_as_expr(seq_or_selectable) - - if isinstance(seq_or_selectable, _ScalarSelect): - return self.__compare( op, seq_or_selectable, negate=negate_op) - - elif isinstance(seq_or_selectable, _SelectBaseMixin): - # TODO: if we ever want to support (x, y, z) IN (select x, y, z from table), - # we would need a multi-column version of as_scalar() to produce a multi- - # column selectable that does not export itself as a FROM clause - return self.__compare( op, seq_or_selectable.as_scalar(), negate=negate_op) - - elif isinstance(seq_or_selectable, Selectable): - return self.__compare( op, seq_or_selectable, negate=negate_op) - - # Handle non selectable arguments as sequences - args = [] - for o in seq_or_selectable: - if not _is_literal(o): - if not isinstance( o, _CompareMixin): - raise exc.InvalidRequestError( - "in() function accepts either a list of non-selectable values, " - "or a selectable: %r" % o) - else: - o = self._bind_param(op, o) - args.append(o) - - if len(args) == 0: - # Special case handling for empty IN's, behave like comparison - # against zero row selectable. We use != to build the - # contradiction as it handles NULL values appropriately, i.e. - # "not (x IN ())" should not return NULL values for x. - util.warn("The IN-predicate on \"%s\" was invoked with an empty sequence. " - "This results in a contradiction, which nonetheless can be " - "expensive to evaluate. Consider alternative strategies for " - "improved performance." % self) - - return self != self - - return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op) - - def __neg__(self): - return _UnaryExpression(self, operator=operators.neg) - - def startswith(self, other, escape=None): - """Produce the clause ``LIKE '%'``""" - - # use __radd__ to force string concat behavior - return self.__compare( - operators.like_op, - literal_column("'%'", type_=sqltypes.String).__radd__( - self._check_literal(operators.like_op, other) - ), - escape=escape) - - def endswith(self, other, escape=None): - """Produce the clause ``LIKE '%'``""" - - return self.__compare( - operators.like_op, - literal_column("'%'", type_=sqltypes.String) + - self._check_literal(operators.like_op, other), - escape=escape) - - def contains(self, other, escape=None): - """Produce the clause ``LIKE '%%'``""" - - return self.__compare( - operators.like_op, - literal_column("'%'", type_=sqltypes.String) + - self._check_literal(operators.like_op, other) + - literal_column("'%'", type_=sqltypes.String), - escape=escape) - - def match(self, other): - """Produce a MATCH clause, i.e. ``MATCH ''`` - - The allowed contents of ``other`` are database backend specific. - - """ - return self.__compare(operators.match_op, self._check_literal(operators.match_op, other)) - - def label(self, name): - """Produce a column label, i.e. `` AS ``. - - if 'name' is None, an anonymous label name will be generated. - - """ - return _Label(name, self, self.type) - - def desc(self): - """Produce a DESC clause, i.e. `` DESC``""" - - return desc(self) - - def asc(self): - """Produce a ASC clause, i.e. 
`` ASC``""" - - return asc(self) - - def distinct(self): - """Produce a DISTINCT clause, i.e. ``DISTINCT ``""" - return _UnaryExpression(self, operator=operators.distinct_op, type_=self.type) - - def between(self, cleft, cright): - """Produce a BETWEEN clause, i.e. `` BETWEEN AND ``""" - - return _BinaryExpression( - self, - ClauseList( - self._check_literal(operators.and_, cleft), - self._check_literal(operators.and_, cright), - operator=operators.and_, - group=False), - operators.between_op) - - def collate(self, collation): - """Produce a COLLATE clause, i.e. `` COLLATE utf8_bin``""" - - return collate(self, collation) - - def op(self, operator): - """produce a generic operator function. - - e.g.:: - - somecolumn.op("*")(5) - - produces:: - - somecolumn * 5 - - - :param operator: a string which will be output as the infix operator between - this :class:`ClauseElement` and the expression passed to the - generated function. - - This function can also be used to make bitwise operators explicit. For example:: - - somecolumn.op('&')(0xff) - - is a bitwise AND of the value in somecolumn. - - """ - return lambda other: self.__operate(operator, other) - - def _bind_param(self, operator, obj): - return _BindParamClause(None, obj, - _compared_to_operator=operator, - _compared_to_type=self.type, unique=True) - - def _check_literal(self, operator, other): - if isinstance(other, _BindParamClause) and \ - isinstance(other.type, sqltypes.NullType): - # TODO: perhaps we should not mutate the incoming bindparam() - # here and instead make a copy of it. this might - # be the only place that we're mutating an incoming construct. - other.type = self.type - return other - elif hasattr(other, '__clause_element__'): - return other.__clause_element__() - elif not isinstance(other, ClauseElement): - return self._bind_param(operator, other) - elif isinstance(other, (_SelectBaseMixin, Alias)): - return other.as_scalar() - else: - return other - - -class ColumnElement(ClauseElement, _CompareMixin): - """Represent an element that is usable within the "column clause" portion of a ``SELECT`` statement. - - This includes columns associated with tables, aliases, and - subqueries, expressions, function calls, SQL keywords such as - ``NULL``, literals, etc. :class:`ColumnElement` is the ultimate base - class for all such elements. - - :class:`ColumnElement` supports the ability to be a *proxy* element, - which indicates that the :class:`ColumnElement` may be associated with - a :class:`Selectable` which was derived from another :class:`Selectable`. - An example of a "derived" :class:`Selectable` is an :class:`Alias` of a - :class:`~sqlalchemy.schema.Table`. - - A :class:`ColumnElement`, by subclassing the :class:`_CompareMixin` mixin - class, provides the ability to generate new :class:`ClauseElement` - objects using Python expressions. See the :class:`_CompareMixin` - docstring for more details. 
- - """ - - __visit_name__ = 'column' - primary_key = False - foreign_keys = [] - quote = None - _label = None - - @property - def _select_iterable(self): - return (self, ) - - @util.memoized_property - def base_columns(self): - return util.column_set(c for c in self.proxy_set - if not hasattr(c, 'proxies')) - - @util.memoized_property - def proxy_set(self): - s = util.column_set([self]) - if hasattr(self, 'proxies'): - for c in self.proxies: - s.update(c.proxy_set) - return s - - def shares_lineage(self, othercolumn): - """Return True if the given :class:`ColumnElement` - has a common ancestor to this :class:`ColumnElement`.""" - - return bool(self.proxy_set.intersection(othercolumn.proxy_set)) - - def _make_proxy(self, selectable, name=None): - """Create a new :class:`ColumnElement` representing this - :class:`ColumnElement` as it appears in the select list of a - descending selectable. - - """ - - if name: - co = ColumnClause(name, selectable, type_=getattr(self, 'type', None)) - else: - name = str(self) - co = ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None)) - - co.proxies = [self] - selectable.columns[name] = co - return co - - def compare(self, other, use_proxies=False, equivalents=None, **kw): - """Compare this ColumnElement to another. - - Special arguments understood: - - :param use_proxies: when True, consider two columns that - share a common base column as equivalent (i.e. shares_lineage()) - - :param equivalents: a dictionary of columns as keys mapped to sets - of columns. If the given "other" column is present in this dictionary, - if any of the columns in the correponding set() pass the comparison - test, the result is True. This is used to expand the comparison to - other columns that may be known to be equivalent to this one via - foreign key or other criterion. - - """ - to_compare = (other, ) - if equivalents and other in equivalents: - to_compare = equivalents[other].union(to_compare) - - for oth in to_compare: - if use_proxies and self.shares_lineage(oth): - return True - elif oth is self: - return True - else: - return False - - @util.memoized_property - def anon_label(self): - """provides a constant 'anonymous label' for this ColumnElement. - - This is a label() expression which will be named at compile time. - The same label() is returned each time anon_label is called so - that expressions can reference anon_label multiple times, producing - the same label name at compile time. - - the compiler uses this function automatically at compile time - for expressions that are known to be 'unnamed' like binary - expressions and function calls. - - """ - return _generated_label("%%(%d %s)s" % (id(self), getattr(self, 'name', 'anon'))) - -class ColumnCollection(util.OrderedProperties): - """An ordered dictionary that stores a list of ColumnElement - instances. - - Overrides the ``__eq__()`` method to produce SQL clauses between - sets of correlated columns. - - """ - - def __init__(self, *cols): - super(ColumnCollection, self).__init__() - self.update((c.key, c) for c in cols) - - def __str__(self): - return repr([str(c) for c in self]) - - def replace(self, column): - """add the given column to this collection, removing unaliased - versions of this column as well as existing columns with the - same key. - - e.g.:: - - t = Table('sometable', metadata, Column('col1', Integer)) - t.columns.replace(Column('col1', Integer, key='columnone')) - - will remove the original 'col1' from the collection, and add - the new column under the name 'columnname'. 
- - Used by schema.Column to override columns during table reflection. - - """ - if column.name in self and column.key != column.name: - other = self[column.name] - if other.name == other.key: - del self[other.name] - util.OrderedProperties.__setitem__(self, column.key, column) - - def add(self, column): - """Add a column to this collection. - - The key attribute of the column will be used as the hash key - for this dictionary. - - """ - self[column.key] = column - - def __setitem__(self, key, value): - if key in self: - # this warning is primarily to catch select() statements which - # have conflicting column names in their exported columns collection - existing = self[key] - if not existing.shares_lineage(value): - util.warn(("Column %r on table %r being replaced by another " - "column with the same key. Consider use_labels " - "for select() statements.") % (key, getattr(existing, 'table', None))) - util.OrderedProperties.__setitem__(self, key, value) - - def remove(self, column): - del self[column.key] - - def extend(self, iter): - for c in iter: - self.add(c) - - __hash__ = None - - def __eq__(self, other): - l = [] - for c in other: - for local in self: - if c.shares_lineage(local): - l.append(c==local) - return and_(*l) - - def __contains__(self, other): - if not isinstance(other, basestring): - raise exc.ArgumentError("__contains__ requires a string argument") - return util.OrderedProperties.__contains__(self, other) - - def contains_column(self, col): - # have to use a Set here, because it will compare the identity - # of the column, not just using "==" for comparison which will always return a - # "True" value (i.e. a BinaryClause...) - return col in util.column_set(self) - -class ColumnSet(util.ordered_column_set): - def contains_column(self, col): - return col in self - - def extend(self, cols): - for col in cols: - self.add(col) - - def __add__(self, other): - return list(self) + list(other) - - def __eq__(self, other): - l = [] - for c in other: - for local in self: - if c.shares_lineage(local): - l.append(c==local) - return and_(*l) - - def __hash__(self): - return hash(tuple(x for x in self)) - -class Selectable(ClauseElement): - """mark a class as being selectable""" - __visit_name__ = 'selectable' - -class FromClause(Selectable): - """Represent an element that can be used within the ``FROM`` - clause of a ``SELECT`` statement. - - """ - __visit_name__ = 'fromclause' - named_with_column = False - _hide_froms = [] - quote = None - schema = None - - def count(self, whereclause=None, **params): - """return a SELECT COUNT generated against this :class:`FromClause`.""" - - if self.primary_key: - col = list(self.primary_key)[0] - else: - col = list(self.columns)[0] - return select( - [func.count(col).label('tbl_row_count')], - whereclause, - from_obj=[self], - **params) - - def select(self, whereclause=None, **params): - """return a SELECT of this :class:`FromClause`.""" - - return select([self], whereclause, **params) - - def join(self, right, onclause=None, isouter=False): - """return a join of this :class:`FromClause` against another :class:`FromClause`.""" - - return Join(self, right, onclause, isouter) - - def outerjoin(self, right, onclause=None): - """return an outer join of this :class:`FromClause` against another :class:`FromClause`.""" - - return Join(self, right, onclause, True) - - def alias(self, name=None): - """return an alias of this :class:`FromClause`. 
- - For table objects, this has the effect of the table being rendered - as ``tablename AS aliasname`` in a SELECT statement. - For select objects, the effect is that of creating a named - subquery, i.e. ``(select ...) AS aliasname``. - The :func:`alias()` method is the general way to create - a "subquery" out of an existing SELECT. - - The ``name`` parameter is optional, and if left blank an - "anonymous" name will be generated at compile time, guaranteed - to be unique against other anonymous constructs used in the - same statement. - - """ - - return Alias(self, name) - - def is_derived_from(self, fromclause): - """Return True if this FromClause is 'derived' from the given FromClause. - - An example would be an Alias of a Table is derived from that Table. - - """ - return fromclause in self._cloned_set - - def replace_selectable(self, old, alias): - """replace all occurences of FromClause 'old' with the given Alias - object, returning a copy of this :class:`FromClause`. - - """ - global ClauseAdapter - if ClauseAdapter is None: - from sqlalchemy.sql.util import ClauseAdapter - return ClauseAdapter(alias).traverse(self) - - def correspond_on_equivalents(self, column, equivalents): - """Return corresponding_column for the given column, or if None - search for a match in the given dictionary. - - """ - col = self.corresponding_column(column, require_embedded=True) - if col is None and col in equivalents: - for equiv in equivalents[col]: - nc = self.corresponding_column(equiv, require_embedded=True) - if nc: - return nc - return col - - def corresponding_column(self, column, require_embedded=False): - """Given a :class:`ColumnElement`, return the exported :class:`ColumnElement` - object from this :class:`Selectable` which corresponds to that - original :class:`~sqlalchemy.schema.Column` via a common anscestor column. - - :param column: the target :class:`ColumnElement` to be matched - - :param require_embedded: only return corresponding columns for the given - :class:`ColumnElement`, if the given :class:`ColumnElement` is - actually present within a sub-element of this - :class:`FromClause`. Normally the column will match if it merely - shares a common anscestor with one of the exported columns - of this :class:`FromClause`. - - """ - # dont dig around if the column is locally present - if self.c.contains_column(column): - return column - - col, intersect = None, None - target_set = column.proxy_set - cols = self.c - for c in cols: - i = target_set.intersection(itertools.chain(*[p._cloned_set for p in c.proxy_set])) - - if i and \ - (not require_embedded or c.proxy_set.issuperset(target_set)): - - if col is None: - # no corresponding column yet, pick this one. - col, intersect = c, i - elif len(i) > len(intersect): - # 'c' has a larger field of correspondence than 'col'. - # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches - # a1.c.x->table.c.x better than - # selectable.c.x->table.c.x does. - col, intersect = c, i - elif i == intersect: - # they have the same field of correspondence. - # see which proxy_set has fewer columns in it, which indicates - # a closer relationship with the root column. 
Also take into - # account the "weight" attribute which CompoundSelect() uses to - # give higher precedence to columns based on vertical position - # in the compound statement, and discard columns that have no - # reference to the target column (also occurs with - # CompoundSelect) - col_distance = util.reduce(operator.add, - [sc._annotations.get('weight', 1) - for sc in col.proxy_set - if sc.shares_lineage(column)] - ) - c_distance = util.reduce(operator.add, - [sc._annotations.get('weight', 1) - for sc in c.proxy_set - if sc.shares_lineage(column)] - ) - if c_distance < col_distance: - col, intersect = c, i - return col - - @property - def description(self): - """a brief description of this FromClause. - - Used primarily for error message formatting. - - """ - return getattr(self, 'name', self.__class__.__name__ + " object") - - def _reset_exported(self): - """delete memoized collections when a FromClause is cloned.""" - - for attr in ('_columns', '_primary_key' '_foreign_keys', 'locate_all_froms'): - self.__dict__.pop(attr, None) - - @util.memoized_property - def _columns(self): - """Return the collection of Column objects contained by this FromClause.""" - - self._export_columns() - return self._columns - - @util.memoized_property - def _primary_key(self): - """Return the collection of Column objects which comprise the - primary key of this FromClause.""" - - self._export_columns() - return self._primary_key - - @util.memoized_property - def _foreign_keys(self): - """Return the collection of ForeignKey objects which this - FromClause references.""" - - self._export_columns() - return self._foreign_keys - - columns = property(attrgetter('_columns'), doc=_columns.__doc__) - primary_key = property( - attrgetter('_primary_key'), - doc=_primary_key.__doc__) - foreign_keys = property( - attrgetter('_foreign_keys'), - doc=_foreign_keys.__doc__) - - # synonyms for 'columns' - c = _select_iterable = property(attrgetter('columns'), doc=_columns.__doc__) - - def _export_columns(self): - """Initialize column collections.""" - - self._columns = ColumnCollection() - self._primary_key = ColumnSet() - self._foreign_keys = set() - self._populate_column_collection() - - def _populate_column_collection(self): - pass - -class _BindParamClause(ColumnElement): - """Represent a bind parameter. - - Public constructor is the :func:`bindparam()` function. - - """ - - __visit_name__ = 'bindparam' - quote = None - - def __init__(self, key, value, type_=None, unique=False, - isoutparam=False, required=False, - _compared_to_operator=None, - _compared_to_type=None): - """Construct a _BindParamClause. - - key - the key for this bind param. Will be used in the generated - SQL statement for dialects that use named parameters. This - value may be modified when part of a compilation operation, - if other :class:`_BindParamClause` objects exist with the same - key, or if its length is too long and truncation is - required. - - value - Initial value for this bind param. This value may be - overridden by the dictionary of parameters sent to statement - compilation/execution. - - type\_ - A ``TypeEngine`` object that will be used to pre-process the - value corresponding to this :class:`_BindParamClause` at - execution time. - - unique - if True, the key name of this BindParamClause will be - modified if another :class:`_BindParamClause` of the same name - already has been located within the containing - :class:`ClauseElement`. - - required - a value is required at execution time. 
- - isoutparam - if True, the parameter should be treated like a stored procedure "OUT" - parameter. - - """ - if unique: - self.key = _generated_label("%%(%d %s)s" % (id(self), key or 'param')) - else: - self.key = key or _generated_label("%%(%d param)s" % id(self)) - self._orig_key = key or 'param' - self.unique = unique - self.value = value - self.isoutparam = isoutparam - self.required = required - - if type_ is None: - if _compared_to_type is not None: - self.type = _compared_to_type._coerce_compared_value(_compared_to_operator, value) - else: - self.type = sqltypes.type_map.get(type(value), sqltypes.NULLTYPE) - elif isinstance(type_, type): - self.type = type_() - else: - self.type = type_ - - def _clone(self): - c = ClauseElement._clone(self) - if self.unique: - c.key = _generated_label("%%(%d %s)s" % (id(c), c._orig_key or 'param')) - return c - - def _convert_to_unique(self): - if not self.unique: - self.unique = True - self.key = _generated_label("%%(%d %s)s" % (id(self), - self._orig_key or 'param')) - - def bind_processor(self, dialect): - return self.type.dialect_impl(dialect).bind_processor(dialect) - - def compare(self, other, **kw): - """Compare this :class:`_BindParamClause` to the given clause.""" - - return isinstance(other, _BindParamClause) and \ - self.type._compare_type_affinity(other.type) and \ - self.value == other.value - - def __getstate__(self): - """execute a deferred value for serialization purposes.""" - - d = self.__dict__.copy() - v = self.value - if util.callable(v): - v = v() - d['value'] = v - return d - - def __repr__(self): - return "_BindParamClause(%r, %r, type_=%r)" % ( - self.key, self.value, self.type - ) - -class _TypeClause(ClauseElement): - """Handle a type keyword in a SQL statement. - - Used by the ``Case`` statement. - - """ - - __visit_name__ = 'typeclause' - - def __init__(self, type): - self.type = type - - -class _Generative(object): - """Allow a ClauseElement to generate itself via the - @_generative decorator. - - """ - - def _generate(self): - s = self.__class__.__new__(self.__class__) - s.__dict__ = self.__dict__.copy() - return s - - -class Executable(_Generative): - """Mark a ClauseElement as supporting execution. - - :class:`Executable` is a superclass for all "statement" types - of objects, including :func:`select`, :func:`delete`, :func:`update`, - :func:`insert`, :func:`text`. - - """ - - supports_execution = True - _execution_options = util.frozendict() - - @_generative - def execution_options(self, **kw): - """ Set non-SQL options for the statement which take effect during execution. - - Current options include: - - * autocommit - when True, a COMMIT will be invoked after execution - when executed in 'autocommit' mode, i.e. when an explicit transaction - is not begun on the connection. Note that DBAPI connections by - default are always in a transaction - SQLAlchemy uses rules applied - to different kinds of statements to determine if COMMIT will be invoked - in order to provide its "autocommit" feature. Typically, all - INSERT/UPDATE/DELETE statements as well as CREATE/DROP statements - have autocommit behavior enabled; SELECT constructs do not. Use this - option when invokving a SELECT or other specific SQL construct - where COMMIT is desired (typically when calling stored procedures - and such). - - * stream_results - indicate to the dialect that results should be - "streamed" and not pre-buffered, if possible. This is a limitation - of many DBAPIs. The flag is currently understood only by the - psycopg2 dialect. 
- - See also: - - :meth:`sqlalchemy.engine.base.Connection.execution_options()` - - :meth:`sqlalchemy.orm.query.Query.execution_options()` - - """ - self._execution_options = self._execution_options.union(kw) - -# legacy, some outside users may be calling this + 'Alias', 'any_', 'all_', 'ClauseElement', 'ColumnCollection', 'ColumnElement', + 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 'Lateral', + 'Select', + 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', 'between', + 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'distinct', + 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier', + 'collate', 'insert', 'intersect', 'intersect_all', 'join', 'label', + 'lateral', 'literal', 'literal_column', 'not_', 'null', 'nullsfirst', + 'nullslast', + 'or_', 'outparam', 'outerjoin', 'over', 'select', 'subquery', + 'table', 'text', + 'tuple_', 'type_coerce', 'union', 'union_all', 'update', 'within_group', + 'TableSample', 'tablesample'] + + +from .visitors import Visitable +from .functions import func, modifier, FunctionElement, Function +from ..util.langhelpers import public_factory +from .elements import ClauseElement, ColumnElement,\ + BindParameter, CollectionAggregate, UnaryExpression, BooleanClauseList, \ + Label, Cast, Case, ColumnClause, TextClause, Over, Null, \ + True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \ + Grouping, WithinGroup, not_, \ + collate, literal_column, between,\ + literal, outparam, TypeCoerce, ClauseList, FunctionFilter + +from .elements import SavepointClause, RollbackToSavepointClause, \ + ReleaseSavepointClause + +from .base import ColumnCollection, Generative, Executable, \ + PARSE_AUTOCOMMIT + +from .selectable import Alias, Join, Select, Selectable, TableClause, \ + CompoundSelect, CTE, FromClause, FromGrouping, Lateral, SelectBase, \ + alias, GenerativeSelect, subquery, HasCTE, HasPrefixes, HasSuffixes, \ + lateral, Exists, ScalarSelect, TextAsFrom, TableSample, tablesample + + +from .dml import Insert, Update, Delete, UpdateBase, ValuesBase + +# factory functions - these pull class-bound constructors and classmethods +# from SQL elements and selectables into public functions. This allows +# the functions to be available in the sqlalchemy.sql.* namespace and +# to be auto-cross-documenting from the function to the class itself. 
+ +all_ = public_factory(CollectionAggregate._create_all, ".expression.all_") +any_ = public_factory(CollectionAggregate._create_any, ".expression.any_") +and_ = public_factory(BooleanClauseList.and_, ".expression.and_") +or_ = public_factory(BooleanClauseList.or_, ".expression.or_") +bindparam = public_factory(BindParameter, ".expression.bindparam") +select = public_factory(Select, ".expression.select") +text = public_factory(TextClause._create_text, ".expression.text") +table = public_factory(TableClause, ".expression.table") +column = public_factory(ColumnClause, ".expression.column") +over = public_factory(Over, ".expression.over") +within_group = public_factory(WithinGroup, ".expression.within_group") +label = public_factory(Label, ".expression.label") +case = public_factory(Case, ".expression.case") +cast = public_factory(Cast, ".expression.cast") +extract = public_factory(Extract, ".expression.extract") +tuple_ = public_factory(Tuple, ".expression.tuple_") +except_ = public_factory(CompoundSelect._create_except, ".expression.except_") +except_all = public_factory( + CompoundSelect._create_except_all, ".expression.except_all") +intersect = public_factory( + CompoundSelect._create_intersect, ".expression.intersect") +intersect_all = public_factory( + CompoundSelect._create_intersect_all, ".expression.intersect_all") +union = public_factory(CompoundSelect._create_union, ".expression.union") +union_all = public_factory( + CompoundSelect._create_union_all, ".expression.union_all") +exists = public_factory(Exists, ".expression.exists") +nullsfirst = public_factory( + UnaryExpression._create_nullsfirst, ".expression.nullsfirst") +nullslast = public_factory( + UnaryExpression._create_nullslast, ".expression.nullslast") +asc = public_factory(UnaryExpression._create_asc, ".expression.asc") +desc = public_factory(UnaryExpression._create_desc, ".expression.desc") +distinct = public_factory( + UnaryExpression._create_distinct, ".expression.distinct") +type_coerce = public_factory(TypeCoerce, ".expression.type_coerce") +true = public_factory(True_._instance, ".expression.true") +false = public_factory(False_._instance, ".expression.false") +null = public_factory(Null._instance, ".expression.null") +join = public_factory(Join._create_join, ".expression.join") +outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin") +insert = public_factory(Insert, ".expression.insert") +update = public_factory(Update, ".expression.update") +delete = public_factory(Delete, ".expression.delete") +funcfilter = public_factory( + FunctionFilter, ".expression.funcfilter") + + +# internal functions still being called from tests and the ORM, +# these might be better off in some other namespace +from .base import _from_objects +from .elements import _literal_as_text, _clause_element_as_expr,\ + _is_column, _labeled, _only_column_elements, _string_or_unprintable, \ + _truncated_label, _clone, _cloned_difference, _cloned_intersection,\ + _column_as_key, _literal_as_binds, _select_iterables, \ + _corresponding_column_or_error, _literal_as_label_reference, \ + _expression_literal_as_text +from .selectable import _interpret_as_from + + +# old names for compatibility _Executable = Executable - -class _TextClause(Executable, ClauseElement): - """Represent a literal SQL text fragment. - - Public constructor is the :func:`text()` function. - - """ - - __visit_name__ = 'textclause' - - _bind_params_regex = re.compile(r'(? 
RIGHT``.""" - - __visit_name__ = 'binary' - - def __init__(self, left, right, operator, type_=None, negate=None, modifiers=None): - self.left = _literal_as_text(left).self_group(against=operator) - self.right = _literal_as_text(right).self_group(against=operator) - self.operator = operator - self.type = sqltypes.to_instance(type_) - self.negate = negate - if modifiers is None: - self.modifiers = {} - else: - self.modifiers = modifiers - - def __nonzero__(self): - try: - return self.operator(hash(self.left), hash(self.right)) - except: - raise TypeError("Boolean value of this clause is not defined") - - @property - def _from_objects(self): - return self.left._from_objects + self.right._from_objects - - def _copy_internals(self, clone=_clone): - self.left = clone(self.left) - self.right = clone(self.right) - - def get_children(self, **kwargs): - return self.left, self.right - - def compare(self, other, **kw): - """Compare this :class:`_BinaryExpression` against the - given :class:`_BinaryExpression`.""" - - return ( - isinstance(other, _BinaryExpression) and - self.operator == other.operator and - ( - self.left.compare(other.left, **kw) and - self.right.compare(other.right, **kw) or - ( - operators.is_commutative(self.operator) and - self.left.compare(other.right, **kw) and - self.right.compare(other.left, **kw) - ) - ) - ) - - def self_group(self, against=None): - # use small/large defaults for comparison so that unknown - # operators are always parenthesized - if self.operator is not against and operators.is_precedent(self.operator, against): - return _Grouping(self) - else: - return self - - def _negate(self): - if self.negate is not None: - return _BinaryExpression( - self.left, - self.right, - self.negate, - negate=self.operator, - type_=sqltypes.BOOLEANTYPE, - modifiers=self.modifiers) - else: - return super(_BinaryExpression, self)._negate() - -class _Exists(_UnaryExpression): - __visit_name__ = _UnaryExpression.__visit_name__ - _from_objects = [] - - def __init__(self, *args, **kwargs): - if args and isinstance(args[0], (_SelectBaseMixin, _ScalarSelect)): - s = args[0] - else: - if not args: - args = ([literal_column('*')],) - s = select(*args, **kwargs).as_scalar().self_group() - - _UnaryExpression.__init__(self, s, operator=operators.exists, type_=sqltypes.Boolean) - - def select(self, whereclause=None, **params): - return select([self], whereclause, **params) - - def correlate(self, fromclause): - e = self._clone() - e.element = self.element.correlate(fromclause).self_group() - return e - - def select_from(self, clause): - """return a new exists() construct with the given expression set as its FROM - clause. - - """ - e = self._clone() - e.element = self.element.select_from(clause).self_group() - return e - - def where(self, clause): - """return a new exists() construct with the given expression added to its WHERE - clause, joined to the existing clause via AND, if any. - - """ - e = self._clone() - e.element = self.element.where(clause).self_group() - return e - -class Join(FromClause): - """represent a ``JOIN`` construct between two :class:`FromClause` elements. - - The public constructor function for :class:`Join` is the module-level - :func:`join()` function, as well as the :func:`join()` method available - off all :class:`FromClause` subclasses. 
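A brief usage sketch of those two spellings, with hypothetical users/addresses tables::

    from sqlalchemy import table, column, select

    users = table('users', column('id'), column('name'))
    addresses = table('addresses', column('id'), column('user_id'))

    # the module-level join() and the FromClause.join() method are equivalent:
    j = users.join(addresses, users.c.id == addresses.c.user_id)
    stmt = select([users.c.name]).select_from(j)
    # SELECT users.name FROM users
    # JOIN addresses ON users.id = addresses.user_id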
- - """ - __visit_name__ = 'join' - - def __init__(self, left, right, onclause=None, isouter=False): - self.left = _literal_as_text(left) - self.right = _literal_as_text(right).self_group() - - if onclause is None: - self.onclause = self._match_primaries(self.left, self.right) - else: - self.onclause = onclause - - self.isouter = isouter - self.__folded_equivalents = None - - @property - def description(self): - return "Join object on %s(%d) and %s(%d)" % ( - self.left.description, - id(self.left), - self.right.description, - id(self.right)) - - def is_derived_from(self, fromclause): - return fromclause is self or \ - self.left.is_derived_from(fromclause) or\ - self.right.is_derived_from(fromclause) - - def self_group(self, against=None): - return _FromGrouping(self) - - def _populate_column_collection(self): - columns = [c for c in self.left.columns] + [c for c in self.right.columns] - - global sql_util - if not sql_util: - from sqlalchemy.sql import util as sql_util - self._primary_key.extend(sql_util.reduce_columns( - (c for c in columns if c.primary_key), self.onclause)) - self._columns.update((col._label, col) for col in columns) - self._foreign_keys.update(itertools.chain(*[col.foreign_keys for col in columns])) - - def _copy_internals(self, clone=_clone): - self._reset_exported() - self.left = clone(self.left) - self.right = clone(self.right) - self.onclause = clone(self.onclause) - self.__folded_equivalents = None - - def get_children(self, **kwargs): - return self.left, self.right, self.onclause - - def _match_primaries(self, left, right): - global sql_util - if not sql_util: - from sqlalchemy.sql import util as sql_util - if isinstance(left, Join): - left_right = left.right - else: - left_right = None - return sql_util.join_condition(left, right, a_subset=left_right) - - def select(self, whereclause=None, fold_equivalents=False, **kwargs): - """Create a :class:`Select` from this :class:`Join`. - - :param whereclause: the WHERE criterion that will be sent to - the :func:`select()` function - - :param fold_equivalents: based on the join criterion of this - :class:`Join`, do not include - repeat column names in the column list of the resulting - select, for columns that are calculated to be "equivalent" - based on the join criterion of this :class:`Join`. This will - recursively apply to any joins directly nested by this one - as well. This flag is specific to a particular use case - by the ORM and is deprecated as of 0.6. - - :param \**kwargs: all other kwargs are sent to the - underlying :func:`select()` function. - - """ - if fold_equivalents: - global sql_util - if not sql_util: - from sqlalchemy.sql import util as sql_util - util.warn_deprecated("fold_equivalents is deprecated.") - collist = sql_util.folded_equivalents(self) - else: - collist = [self.left, self.right] - - return select(collist, whereclause, from_obj=[self], **kwargs) - - @property - def bind(self): - return self.left.bind or self.right.bind - - def alias(self, name=None): - """Create a :class:`Select` out of this :class:`Join` clause and return an :class:`Alias` of it. - - The :class:`Select` is not correlating. 
- - """ - return self.select(use_labels=True, correlate=False).alias(name) - - @property - def _hide_froms(self): - return itertools.chain(*[_from_objects(x.left, x.right) for x in self._cloned_set]) - - @property - def _from_objects(self): - return [self] + \ - self.onclause._from_objects + \ - self.left._from_objects + \ - self.right._from_objects - -class Alias(FromClause): - """Represents a table or selectable alias (AS). - - Represents an alias, as typically applied to any table or - sub-select within a SQL statement using the ``AS`` keyword (or - without the keyword on certain databases such as Oracle). - - This object is constructed from the :func:`alias()` module level - function as well as the :func:`alias()` method available on all - :class:`FromClause` subclasses. - - """ - - __visit_name__ = 'alias' - named_with_column = True - - def __init__(self, selectable, alias=None): - baseselectable = selectable - while isinstance(baseselectable, Alias): - baseselectable = baseselectable.element - self.original = baseselectable - self.supports_execution = baseselectable.supports_execution - if self.supports_execution: - self._execution_options = baseselectable._execution_options - self.element = selectable - if alias is None: - if self.original.named_with_column: - alias = getattr(self.original, 'name', None) - alias = _generated_label('%%(%d %s)s' % (id(self), alias or 'anon')) - self.name = alias - - @property - def description(self): - # Py3K - #return self.name - # Py2K - return self.name.encode('ascii', 'backslashreplace') - # end Py2K - - def as_scalar(self): - try: - return self.element.as_scalar() - except AttributeError: - raise AttributeError("Element %s does not support 'as_scalar()'" % self.element) - - def is_derived_from(self, fromclause): - if fromclause in self._cloned_set: - return True - return self.element.is_derived_from(fromclause) - - def _populate_column_collection(self): - for col in self.element.columns: - col._make_proxy(self) - - def _copy_internals(self, clone=_clone): - self._reset_exported() - self.element = _clone(self.element) - baseselectable = self.element - while isinstance(baseselectable, Alias): - baseselectable = baseselectable.element - self.original = baseselectable - - def get_children(self, column_collections=True, aliased_selectables=True, **kwargs): - if column_collections: - for c in self.c: - yield c - if aliased_selectables: - yield self.element - - @property - def _from_objects(self): - return [self] - - @property - def bind(self): - return self.element.bind - - -class _Grouping(ColumnElement): - """Represent a grouping within a column expression""" - - __visit_name__ = 'grouping' - - def __init__(self, element): - self.element = element - self.type = getattr(element, 'type', None) - - @property - def _label(self): - return getattr(self.element, '_label', None) or self.anon_label - - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) - - def get_children(self, **kwargs): - return self.element, - - @property - def _from_objects(self): - return self.element._from_objects - - def __getattr__(self, attr): - return getattr(self.element, attr) - - def __getstate__(self): - return {'element':self.element, 'type':self.type} - - def __setstate__(self, state): - self.element = state['element'] - self.type = state['type'] - -class _FromGrouping(FromClause): - """Represent a grouping of a FROM clause""" - __visit_name__ = 'grouping' - - def __init__(self, element): - self.element = element - - @property - def columns(self):
- return self.element.columns - - @property - def _hide_froms(self): - return self.element._hide_froms - - def get_children(self, **kwargs): - return self.element, - - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) - - @property - def _from_objects(self): - return self.element._from_objects - - def __getattr__(self, attr): - return getattr(self.element, attr) - - def __getstate__(self): - return {'element':self.element} - - def __setstate__(self, state): - self.element = state['element'] - -class _Label(ColumnElement): - """Represents a column label (AS). - - Represent a label, as typically applied to any column-level - element using the ``AS`` sql keyword. - - This object is constructed from the :func:`label()` module level - function as well as the :func:`label()` method available on all - :class:`ColumnElement` subclasses. - - """ - - __visit_name__ = 'label' - - def __init__(self, name, element, type_=None): - while isinstance(element, _Label): - element = element.element - self.name = self.key = self._label = name or \ - _generated_label("%%(%d %s)s" % ( - id(self), getattr(element, 'name', 'anon')) - ) - self._element = element - self._type = type_ - self.quote = element.quote - - @util.memoized_property - def type(self): - return sqltypes.to_instance( - self._type or getattr(self._element, 'type', None) - ) - - @util.memoized_property - def element(self): - return self._element.self_group(against=operators.as_) - - def _proxy_attr(name): - get = attrgetter(name) - def attr(self): - return get(self.element) - return property(attr) - - proxies = _proxy_attr('proxies') - base_columns = _proxy_attr('base_columns') - proxy_set = _proxy_attr('proxy_set') - primary_key = _proxy_attr('primary_key') - foreign_keys = _proxy_attr('foreign_keys') - - def get_children(self, **kwargs): - return self.element, - - def _copy_internals(self, clone=_clone): - self.element = clone(self.element) - - @property - def _from_objects(self): - return self.element._from_objects - - def _make_proxy(self, selectable, name = None): - if isinstance(self.element, (Selectable, ColumnElement)): - e = self.element._make_proxy(selectable, name=self.name) - else: - e = column(self.name)._make_proxy(selectable=selectable) - e.proxies.append(self) - return e - -class ColumnClause(_Immutable, ColumnElement): - """Represents a generic column expression from any textual string. - - This includes columns associated with tables, aliases and select - statements, but also any arbitrary text. May or may not be bound - to an underlying :class:`Selectable`. :class:`ColumnClause` is usually - created publicly via the :func:`column()` function or the - :func:`literal_column()` function. - - text - the text of the element. - - selectable - parent selectable. - - type - ``TypeEngine`` object which can associate this :class:`ColumnClause` - with a type. - - is_literal - if True, the :class:`ColumnClause` is assumed to be an exact - expression that will be delivered to the output with no quoting - rules applied regardless of case sensitive settings. The - :func:`literal_column()` function is usually used to create such a - :class:`ColumnClause`.
- - """ - __visit_name__ = 'column' - - onupdate = default = server_default = server_onupdate = None - - def __init__(self, text, selectable=None, type_=None, is_literal=False): - self.key = self.name = text - self.table = selectable - self.type = sqltypes.to_instance(type_) - self.is_literal = is_literal - - @util.memoized_property - def description(self): - # Py3K - #return self.name - # Py2K - return self.name.encode('ascii', 'backslashreplace') - # end Py2K - - @util.memoized_property - def _label(self): - if self.is_literal: - return None - - elif self.table is not None and self.table.named_with_column: - if getattr(self.table, 'schema', None): - label = self.table.schema.replace('.', '_') + "_" + \ - _escape_for_generated(self.table.name) + "_" + \ - _escape_for_generated(self.name) - else: - label = _escape_for_generated(self.table.name) + "_" + \ - _escape_for_generated(self.name) - - return _generated_label(label) - - else: - return self.name - - def label(self, name): - if name is None: - return self - else: - return super(ColumnClause, self).label(name) - - @property - def _from_objects(self): - if self.table is not None: - return [self.table] - else: - return [] - - def _bind_param(self, operator, obj): - return _BindParamClause(self.name, obj, _compared_to_operator=operator, - _compared_to_type=self.type, unique=True) - - def _make_proxy(self, selectable, name=None, attach=True): - # propagate the "is_literal" flag only if we are keeping our name, - # otherwise it's considered to be a label - is_literal = self.is_literal and (name is None or name == self.name) - c = ColumnClause( - name or self.name, - selectable=selectable, - type_=self.type, - is_literal=is_literal - ) - c.proxies = [self] - if attach: - selectable.columns[c.name] = c - return c - -class TableClause(_Immutable, FromClause): - """Represents a "table" construct. - - Note that this represents tables only as another syntactical - construct within SQL expressions; it does not provide schema-level - functionality.
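A short sketch contrasting these lightweight constructs; the accounts table and the raw expression text are hypothetical::

    from sqlalchemy import table, column, literal_column, select

    accounts = table('accounts', column('id'), column('balance'))

    # literal_column() text is rendered exactly as given, with no quoting:
    stmt = select([accounts.c.id,
                   literal_column("balance * 1.05").label("projected")])
    # SELECT accounts.id, balance * 1.05 AS projected FROM accounts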
- - """ - - __visit_name__ = 'table' - - named_with_column = True - - def __init__(self, name, *columns): - super(TableClause, self).__init__() - self.name = self.fullname = name - self._columns = ColumnCollection() - self._primary_key = ColumnSet() - self._foreign_keys = set() - for c in columns: - self.append_column(c) - - def _export_columns(self): - raise NotImplementedError() - - @util.memoized_property - def description(self): - # Py3K - #return self.name - # Py2K - return self.name.encode('ascii', 'backslashreplace') - # end Py2K - - def append_column(self, c): - self._columns[c.name] = c - c.table = self - - def get_children(self, column_collections=True, **kwargs): - if column_collections: - return [c for c in self.c] - else: - return [] - - def count(self, whereclause=None, **params): - """return a SELECT COUNT generated against this :class:`TableClause`.""" - - if self.primary_key: - col = list(self.primary_key)[0] - else: - col = list(self.columns)[0] - return select( - [func.count(col).label('tbl_row_count')], - whereclause, - from_obj=[self], - **params) - - def insert(self, values=None, inline=False, **kwargs): - """Generate an :func:`insert()` construct.""" - - return insert(self, values=values, inline=inline, **kwargs) - - def update(self, whereclause=None, values=None, inline=False, **kwargs): - """Generate an :func:`update()` construct.""" - - return update(self, whereclause=whereclause, - values=values, inline=inline, **kwargs) - - def delete(self, whereclause=None, **kwargs): - """Generate a :func:`delete()` construct.""" - - return delete(self, whereclause, **kwargs) - - @property - def _from_objects(self): - return [self] - -class _SelectBaseMixin(Executable): - """Base class for :class:`Select` and ``CompoundSelects``.""" - - def __init__(self, - use_labels=False, - for_update=False, - limit=None, - offset=None, - order_by=None, - group_by=None, - bind=None, - autocommit=None): - self.use_labels = use_labels - self.for_update = for_update - if autocommit is not None: - util.warn_deprecated("autocommit on select() is deprecated. " - "Use .execution_options(autocommit=True)") - self._execution_options = self._execution_options.union({'autocommit':autocommit}) - self._limit = limit - self._offset = offset - self._bind = bind - - self._order_by_clause = ClauseList(*util.to_list(order_by) or []) - self._group_by_clause = ClauseList(*util.to_list(group_by) or []) - - def as_scalar(self): - """return a 'scalar' representation of this selectable, which can be - used as a column expression. - - Typically, a select statement which has only one column in its columns - clause is eligible to be used as a scalar expression. - - The returned object is an instance of - :class:`_ScalarSelect`. - - """ - return _ScalarSelect(self) - - @_generative - def apply_labels(self): - """return a new selectable with the 'use_labels' flag set to True. - - This will result in column expressions being generated using labels - against their table name, such as "SELECT somecolumn AS - tablename_somecolumn". This allows selectables which contain multiple - FROM clauses to produce a unique set of column names regardless of - name conflicts among the individual FROM clauses. - - """ - self.use_labels = True - - def label(self, name): - """return a 'scalar' representation of this selectable, embedded as a - subquery with a label. - - See also ``as_scalar()``. - - """ - return self.as_scalar().label(name) - - @_generative - @util.deprecated(message="autocommit() is deprecated. 
" - "Use .execution_options(autocommit=True)") - def autocommit(self): - """return a new selectable with the 'autocommit' flag set to True.""" - - self._execution_options = self._execution_options.union({'autocommit':True}) - - def _generate(self): - """Override the default _generate() method to also clear out exported collections.""" - - s = self.__class__.__new__(self.__class__) - s.__dict__ = self.__dict__.copy() - s._reset_exported() - return s - - @_generative - def limit(self, limit): - """return a new selectable with the given LIMIT criterion applied.""" - - self._limit = limit - - @_generative - def offset(self, offset): - """return a new selectable with the given OFFSET criterion applied.""" - - self._offset = offset - - @_generative - def order_by(self, *clauses): - """return a new selectable with the given list of ORDER BY criterion applied. - - The criterion will be appended to any pre-existing ORDER BY criterion. - - """ - self.append_order_by(*clauses) - - @_generative - def group_by(self, *clauses): - """return a new selectable with the given list of GROUP BY criterion applied. - - The criterion will be appended to any pre-existing GROUP BY criterion. - - """ - self.append_group_by(*clauses) - - def append_order_by(self, *clauses): - """Append the given ORDER BY criterion applied to this selectable. - - The criterion will be appended to any pre-existing ORDER BY criterion. - - """ - if len(clauses) == 1 and clauses[0] is None: - self._order_by_clause = ClauseList() - else: - if getattr(self, '_order_by_clause', None) is not None: - clauses = list(self._order_by_clause) + list(clauses) - self._order_by_clause = ClauseList(*clauses) - - def append_group_by(self, *clauses): - """Append the given GROUP BY criterion applied to this selectable. - - The criterion will be appended to any pre-existing GROUP BY criterion. - - """ - if len(clauses) == 1 and clauses[0] is None: - self._group_by_clause = ClauseList() - else: - if getattr(self, '_group_by_clause', None) is not None: - clauses = list(self._group_by_clause) + list(clauses) - self._group_by_clause = ClauseList(*clauses) - - @property - def _from_objects(self): - return [self] - - -class _ScalarSelect(_Grouping): - _from_objects = [] - - def __init__(self, element): - self.element = element - cols = list(element.c) - self.type = cols[0].type - - @property - def columns(self): - raise exc.InvalidRequestError("Scalar Select expression has no columns; " - "use this object directly within a column-level expression.") - c = columns - - def self_group(self, **kwargs): - return self - - def _make_proxy(self, selectable, name): - return list(self.inner_columns)[0]._make_proxy(selectable, name) - -class CompoundSelect(_SelectBaseMixin, FromClause): - """Forms the basis of ``UNION``, ``UNION ALL``, and other - SELECT-based set operations.""" - - __visit_name__ = 'compound_select' - - UNION = util.symbol('UNION') - UNION_ALL = util.symbol('UNION ALL') - EXCEPT = util.symbol('EXCEPT') - EXCEPT_ALL = util.symbol('EXCEPT ALL') - INTERSECT = util.symbol('INTERSECT') - INTERSECT_ALL = util.symbol('INTERSECT ALL') - - def __init__(self, keyword, *selects, **kwargs): - self._should_correlate = kwargs.pop('correlate', False) - self.keyword = keyword - self.selects = [] - - numcols = None - - # some DBs do not like ORDER BY in the inner queries of a UNION, etc. 
- for n, s in enumerate(selects): - s = _clause_element_as_expr(s) - - if not numcols: - numcols = len(s.c) - elif len(s.c) != numcols: - raise exc.ArgumentError( - "All selectables passed to CompoundSelect must " - "have identical numbers of columns; select #%d has %d columns," - " select #%d has %d" % - (1, len(self.selects[0].c), n+1, len(s.c)) - ) - - self.selects.append(s.self_group(self)) - - _SelectBaseMixin.__init__(self, **kwargs) - - def self_group(self, against=None): - return _FromGrouping(self) - - def is_derived_from(self, fromclause): - for s in self.selects: - if s.is_derived_from(fromclause): - return True - return False - - def _populate_column_collection(self): - for cols in zip(*[s.c for s in self.selects]): - # this is a slightly hacky thing - the union exports a column that - # resembles just that of the *first* selectable. to get at a "composite" column, - # particularly foreign keys, you have to dig through the proxies collection - # which we generate below. We may want to improve upon this, - # such as perhaps _make_proxy can accept a list of other columns that - # are "shared" - schema.column can then copy all the ForeignKeys in. - # this would allow the union() to have all those fks too. - proxy = cols[0]._make_proxy( - self, name=self.use_labels and cols[0]._label or None) - - # hand-construct the "proxies" collection to include all derived columns - # place a 'weight' annotation corresponding to how low in the list of - # select()s the column occurs, so that the corresponding_column() operation - # can resolve conflicts - proxy.proxies = [c._annotate({'weight':i + 1}) for i, c in enumerate(cols)] - - def _copy_internals(self, clone=_clone): - self._reset_exported() - self.selects = [clone(s) for s in self.selects] - if hasattr(self, '_col_map'): - del self._col_map - for attr in ('_order_by_clause', '_group_by_clause'): - if getattr(self, attr) is not None: - setattr(self, attr, clone(getattr(self, attr))) - - def get_children(self, column_collections=True, **kwargs): - return (column_collections and list(self.c) or []) + \ - [self._order_by_clause, self._group_by_clause] + list(self.selects) - - def bind(self): - if self._bind: - return self._bind - for s in self.selects: - e = s.bind - if e: - return e - else: - return None - def _set_bind(self, bind): - self._bind = bind - bind = property(bind, _set_bind) - -class Select(_SelectBaseMixin, FromClause): - """Represents a ``SELECT`` statement. - - Select statements support appendable clauses, as well as the - ability to execute themselves and return a result set. - - """ - - __visit_name__ = 'select' - - _prefixes = () - _hints = util.frozendict() - - def __init__(self, - columns, - whereclause=None, - from_obj=None, - distinct=False, - having=None, - correlate=True, - prefixes=None, - **kwargs): - """Construct a Select object. - - The public constructor for Select is the - :func:`select` function; see that function for - argument descriptions. - - Additional generative and mutator methods are available on the - :class:`_SelectBaseMixin` superclass. 
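For illustration, a construction exercising the arguments described above, with hypothetical names::

    from sqlalchemy import table, column, select

    orders = table('orders', column('id'), column('total'), column('status'))

    stmt = select([orders.c.id, orders.c.total],          # columns clause
                  whereclause=orders.c.status == 'open',  # WHERE criterion
                  from_obj=[orders])                      # explicit FROM list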
- - """ - self._should_correlate = correlate - self._distinct = distinct - - self._correlate = set() - self._froms = util.OrderedSet() - - try: - cols_present = bool(columns) - except TypeError: - raise exc.ArgumentError("columns argument to select() must " - "be a Python list or other iterable") - - if cols_present: - self._raw_columns = [] - for c in columns: - c = _literal_as_column(c) - if isinstance(c, _ScalarSelect): - c = c.self_group(against=operators.comma_op) - self._raw_columns.append(c) - - self._froms.update(_from_objects(*self._raw_columns)) - else: - self._raw_columns = [] - - if whereclause is not None: - self._whereclause = _literal_as_text(whereclause) - self._froms.update(_from_objects(self._whereclause)) - else: - self._whereclause = None - - if from_obj is not None: - for f in util.to_list(from_obj): - if _is_literal(f): - self._froms.add(_TextClause(f)) - else: - self._froms.add(f) - - if having is not None: - self._having = _literal_as_text(having) - else: - self._having = None - - if prefixes: - self._prefixes = tuple([_literal_as_text(p) for p in prefixes]) - - _SelectBaseMixin.__init__(self, **kwargs) - - def _get_display_froms(self, existing_froms=None): - """Return the full list of 'from' clauses to be displayed. - - Takes into account a set of existing froms which may be - rendered in the FROM clause of enclosing selects; this Select - may want to leave those absent if it is automatically - correlating. - - """ - froms = self._froms - - toremove = itertools.chain(*[f._hide_froms for f in froms]) - if toremove: - froms = froms.difference(toremove) - - if len(froms) > 1 or self._correlate: - if self._correlate: - froms = froms.difference(_cloned_intersection(froms, self._correlate)) - - if self._should_correlate and existing_froms: - froms = froms.difference(_cloned_intersection(froms, existing_froms)) - - if not len(froms): - raise exc.InvalidRequestError( - "Select statement '%s' returned no FROM clauses " - "due to auto-correlation; specify correlate() " - "to control correlation manually." % self) - - return froms - - @property - def froms(self): - """Return the displayed list of FromClause elements.""" - - return self._get_display_froms() - - @_generative - def with_hint(self, selectable, text, dialect_name=None): - """Add an indexing hint for the given selectable to this :class:`Select`. - - The text of the hint is written specific to a specific backend, and - typically uses Python string substitution syntax to render the name - of the table or alias, such as for Oracle:: - - select([mytable]).with_hint(mytable, "+ index(%(name)s ix_mytable)") - - Would render SQL as:: - - select /*+ index(mytable ix_mytable) */ ... from mytable - - The ``dialect_name`` option will limit the rendering of a particular hint - to a particular backend. Such as, to add hints for both Oracle and - Sybase simultaneously:: - - select([mytable]).\ - with_hint(mytable, "+ index(%(name)s ix_mytable)", 'oracle').\ - with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') - - """ - if not dialect_name: - dialect_name = '*' - self._hints = self._hints.union({(selectable, dialect_name):text}) - - @property - def type(self): - raise exc.InvalidRequestError("Select objects don't have a type. " - "Call as_scalar() on this Select object " - "to return a 'scalar' version of this Select.") - - @util.memoized_instancemethod - def locate_all_froms(self): - """return a Set of all FromClause elements referenced by this Select. 
- - This set is a superset of that returned by the ``froms`` property, which - is specifically for those FromClause elements that would actually be rendered. - - """ - return self._froms.union(_from_objects(*list(self._froms))) - - @property - def inner_columns(self): - """an iterator of all ColumnElement expressions which would - be rendered into the columns clause of the resulting SELECT statement. - - """ - return _select_iterables(self._raw_columns) - - def is_derived_from(self, fromclause): - if self in fromclause._cloned_set: - return True - - for f in self.locate_all_froms(): - if f.is_derived_from(fromclause): - return True - return False - - def _copy_internals(self, clone=_clone): - self._reset_exported() - from_cloned = dict((f, clone(f)) - for f in self._froms.union(self._correlate)) - self._froms = util.OrderedSet(from_cloned[f] for f in self._froms) - self._correlate = set(from_cloned[f] for f in self._correlate) - self._raw_columns = [clone(c) for c in self._raw_columns] - for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): - if getattr(self, attr) is not None: - setattr(self, attr, clone(getattr(self, attr))) - - def get_children(self, column_collections=True, **kwargs): - """return child elements as per the ClauseElement specification.""" - - return (column_collections and list(self.columns) or []) + \ - self._raw_columns + list(self._froms) + \ - [x for x in - (self._whereclause, self._having, - self._order_by_clause, self._group_by_clause) - if x is not None] - - @_generative - def column(self, column): - """return a new select() construct with the given column expression - added to its columns clause. - - """ - - column = _literal_as_column(column) - - if isinstance(column, _ScalarSelect): - column = column.self_group(against=operators.comma_op) - - self._raw_columns = self._raw_columns + [column] - self._froms = self._froms.union(_from_objects(column)) - - @_generative - def with_only_columns(self, columns): - """return a new select() construct with its columns clause replaced - with the given columns. - - """ - - self._raw_columns = [ - isinstance(c, _ScalarSelect) and - c.self_group(against=operators.comma_op) or c - for c in [_literal_as_column(c) for c in columns] - ] - - @_generative - def where(self, whereclause): - """return a new select() construct with the given expression added to its - WHERE clause, joined to the existing clause via AND, if any. - - """ - - self.append_whereclause(whereclause) - - @_generative - def having(self, having): - """return a new select() construct with the given expression added to its HAVING - clause, joined to the existing clause via AND, if any. - - """ - self.append_having(having) - - @_generative - def distinct(self): - """return a new select() construct which will apply DISTINCT to its columns - clause. - - """ - self._distinct = True - - @_generative - def prefix_with(self, clause): - """return a new select() construct which will apply the given expression to the - start of its columns clause, not using any commas. - - """ - clause = _literal_as_text(clause) - self._prefixes = self._prefixes + (clause,) - - @_generative - def select_from(self, fromclause): - """return a new select() construct with the given FROM expression applied to its - list of FROM objects. 
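Taken together, the generative methods above chain naturally, each returning a new select(); a sketch against a hypothetical table::

    from sqlalchemy import table, column, select

    orders = table('orders', column('id'), column('total'))

    stmt = select([orders.c.id]).\
        where(orders.c.total > 100).\
        distinct().\
        prefix_with("SQL_CALC_FOUND_ROWS")   # e.g. a MySQL-specific prefix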
- - """ - fromclause = _literal_as_text(fromclause) - self._froms = self._froms.union([fromclause]) - - @_generative - def correlate(self, *fromclauses): - """return a new select() construct which will correlate the given FROM clauses to - that of an enclosing select(), if a match is found. - - By "match", the given fromclause must be present in this select's list of FROM - objects and also present in an enclosing select's list of FROM objects. - - Calling this method turns off the select's default behavior of - "auto-correlation". Normally, select() auto-correlates all of its FROM clauses to - those of an embedded select when compiled. - - If the fromclause is None, correlation is disabled for the returned select(). - - """ - self._should_correlate = False - if fromclauses == (None,): - self._correlate = set() - else: - self._correlate = self._correlate.union(fromclauses) - - def append_correlation(self, fromclause): - """append the given correlation expression to this select() construct.""" - - self._should_correlate = False - self._correlate = self._correlate.union([fromclause]) - - def append_column(self, column): - """append the given column expression to the columns clause of this select() - construct. - - """ - column = _literal_as_column(column) - - if isinstance(column, _ScalarSelect): - column = column.self_group(against=operators.comma_op) - - self._raw_columns = self._raw_columns + [column] - self._froms = self._froms.union(_from_objects(column)) - self._reset_exported() - - def append_prefix(self, clause): - """append the given columns clause prefix expression to this select() - construct. - - """ - clause = _literal_as_text(clause) - self._prefixes = self._prefixes + (clause,) - - def append_whereclause(self, whereclause): - """append the given expression to this select() construct's WHERE criterion. - - The expression will be joined to existing WHERE criterion via AND. - - """ - whereclause = _literal_as_text(whereclause) - self._froms = self._froms.union(_from_objects(whereclause)) - - if self._whereclause is not None: - self._whereclause = and_(self._whereclause, whereclause) - else: - self._whereclause = whereclause - - def append_having(self, having): - """append the given expression to this select() construct's HAVING criterion. - - The expression will be joined to existing HAVING criterion via AND. - - """ - if self._having is not None: - self._having = and_(self._having, _literal_as_text(having)) - else: - self._having = _literal_as_text(having) - - def append_from(self, fromclause): - """append the given FromClause expression to this select() construct's FROM - clause. - - """ - if _is_literal(fromclause): - fromclause = _TextClause(fromclause) - - self._froms = self._froms.union([fromclause]) - - def __exportable_columns(self): - for column in self._raw_columns: - if isinstance(column, Selectable): - for co in column.columns: - yield co - elif isinstance(column, ColumnElement): - yield column - else: - continue - - def _populate_column_collection(self): - for c in self.__exportable_columns(): - c._make_proxy(self, name=self.use_labels and c._label or None) - - def self_group(self, against=None): - """return a 'grouping' construct as per the ClauseElement specification. - - This produces an element that can be embedded in an expression. Note that - this method is called automatically as needed when constructing expressions. 
- - """ - if isinstance(against, CompoundSelect): - return self - return _FromGrouping(self) - - def union(self, other, **kwargs): - """return a SQL UNION of this select() construct against the given selectable.""" - - return union(self, other, **kwargs) - - def union_all(self, other, **kwargs): - """return a SQL UNION ALL of this select() construct against the given - selectable. - - """ - return union_all(self, other, **kwargs) - - def except_(self, other, **kwargs): - """return a SQL EXCEPT of this select() construct against the given selectable.""" - - return except_(self, other, **kwargs) - - def except_all(self, other, **kwargs): - """return a SQL EXCEPT ALL of this select() construct against the given - selectable. - - """ - return except_all(self, other, **kwargs) - - def intersect(self, other, **kwargs): - """return a SQL INTERSECT of this select() construct against the given - selectable. - - """ - return intersect(self, other, **kwargs) - - def intersect_all(self, other, **kwargs): - """return a SQL INTERSECT ALL of this select() construct against the given - selectable. - - """ - return intersect_all(self, other, **kwargs) - - def bind(self): - if self._bind: - return self._bind - if not self._froms: - for c in self._raw_columns: - e = c.bind - if e: - self._bind = e - return e - else: - e = list(self._froms)[0].bind - if e: - self._bind = e - return e - - return None - - def _set_bind(self, bind): - self._bind = bind - bind = property(bind, _set_bind) - -class _UpdateBase(Executable, ClauseElement): - """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" - - __visit_name__ = 'update_base' - - _execution_options = Executable._execution_options.union({'autocommit':True}) - kwargs = util.frozendict() - - def _process_colparams(self, parameters): - if isinstance(parameters, (list, tuple)): - pp = {} - for i, c in enumerate(self.table.c): - pp[c.key] = parameters[i] - return pp - else: - return parameters - - def params(self, *arg, **kw): - raise NotImplementedError( - "params() is not supported for INSERT/UPDATE/DELETE statements." - " To set the values for an INSERT or UPDATE statement, use" - " stmt.values(**parameters).") - - def bind(self): - return self._bind or self.table.bind - - def _set_bind(self, bind): - self._bind = bind - bind = property(bind, _set_bind) - - _returning_re = re.compile(r'(?:firebird|postgres(?:ql)?)_returning') - def _process_deprecated_kw(self, kwargs): - for k in list(kwargs): - m = self._returning_re.match(k) - if m: - self._returning = kwargs.pop(k) - util.warn_deprecated( - "The %r argument is deprecated. Please " - "use statement.returning(col1, col2, ...)" % k - ) - return kwargs - - @_generative - def returning(self, *cols): - """Add a RETURNING or equivalent clause to this statement. - - The given list of columns represent columns within the table - that is the target of the INSERT, UPDATE, or DELETE. Each - element can be any column expression. :class:`~sqlalchemy.schema.Table` - objects will be expanded into their individual columns. - - Upon compilation, a RETURNING clause, or database equivalent, - will be rendered within the statement. For INSERT and UPDATE, - the values are the newly inserted/updated values. For DELETE, - the values are those of the rows which were deleted. - - Upon execution, the values of the columns to be returned - are made available via the result set and can be iterated - using ``fetchone()`` and similar. For DBAPIs which do not - natively support returning values (i.e. 
cx_oracle), - SQLAlchemy will approximate this behavior at the result level - so that a reasonable amount of behavioral neutrality is - provided. - - Note that not all databases/DBAPIs - support RETURNING. For those backends with no support, - an exception is raised upon compilation and/or execution. - For those who do support it, the functionality across backends - varies greatly, including restrictions on executemany() - and other statements which return multiple rows. Please - read the documentation notes for the database in use in - order to determine the availability of RETURNING. - - """ - self._returning = cols - -class _ValuesBase(_UpdateBase): - - __visit_name__ = 'values_base' - - def __init__(self, table, values): - self.table = table - self.parameters = self._process_colparams(values) - - @_generative - def values(self, *args, **kwargs): - """specify the VALUES clause for an INSERT statement, or the SET clause for an - UPDATE. - - \**kwargs - key= arguments - - \*args - A single dictionary can be sent as the first positional argument. This - allows non-string based keys, such as Column objects, to be used. - - """ - if args: - v = args[0] - else: - v = {} - - if self.parameters is None: - self.parameters = self._process_colparams(v) - self.parameters.update(kwargs) - else: - self.parameters = self.parameters.copy() - self.parameters.update(self._process_colparams(v)) - self.parameters.update(kwargs) - -class Insert(_ValuesBase): - """Represent an INSERT construct. - - The :class:`Insert` object is created using the :func:`insert()` function. - - """ - __visit_name__ = 'insert' - - _prefixes = () - - def __init__(self, - table, - values=None, - inline=False, - bind=None, - prefixes=None, - returning=None, - **kwargs): - _ValuesBase.__init__(self, table, values) - self._bind = bind - self.select = None - self.inline = inline - self._returning = returning - if prefixes: - self._prefixes = tuple([_literal_as_text(p) for p in prefixes]) - - if kwargs: - self.kwargs = self._process_deprecated_kw(kwargs) - - def get_children(self, **kwargs): - if self.select is not None: - return self.select, - else: - return () - - def _copy_internals(self, clone=_clone): - # TODO: coverage - self.parameters = self.parameters.copy() - - @_generative - def prefix_with(self, clause): - """Add a word or expression between INSERT and INTO. Generative. - - If multiple prefixes are supplied, they will be separated with - spaces. - - """ - clause = _literal_as_text(clause) - self._prefixes = self._prefixes + (clause,) - -class Update(_ValuesBase): - """Represent an Update construct. - - The :class:`Update` object is created using the :func:`update()` function. 
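A usage sketch for the DML constructs above, against a hypothetical users table::

    from sqlalchemy import table, column, insert, update, delete

    users = table('users', column('id'), column('name'))

    ins = insert(users).values(name='alice')
    upd = update(users).where(users.c.id == 1).values(name='bob')
    # values() also accepts a dict, which allows Column objects as keys:
    upd2 = update(users).where(users.c.id == 1).values({users.c.name: 'bob'})
    stmt = delete(users).where(users.c.id == 1)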
- - """ - __visit_name__ = 'update' - - def __init__(self, - table, - whereclause, - values=None, - inline=False, - bind=None, - returning=None, - **kwargs): - _ValuesBase.__init__(self, table, values) - self._bind = bind - self._returning = returning - if whereclause is not None: - self._whereclause = _literal_as_text(whereclause) - else: - self._whereclause = None - self.inline = inline - - if kwargs: - self.kwargs = self._process_deprecated_kw(kwargs) - - def get_children(self, **kwargs): - if self._whereclause is not None: - return self._whereclause, - else: - return () - - def _copy_internals(self, clone=_clone): - # TODO: coverage - self._whereclause = clone(self._whereclause) - self.parameters = self.parameters.copy() - - @_generative - def where(self, whereclause): - """return a new update() construct with the given expression added to its WHERE - clause, joined to the existing clause via AND, if any. - - """ - if self._whereclause is not None: - self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) - else: - self._whereclause = _literal_as_text(whereclause) - - -class Delete(_UpdateBase): - """Represent a DELETE construct. - - The :class:`Delete` object is created using the :func:`delete()` function. - - """ - - __visit_name__ = 'delete' - - def __init__(self, - table, - whereclause, - bind=None, - returning =None, - **kwargs): - self._bind = bind - self.table = table - self._returning = returning - - if whereclause is not None: - self._whereclause = _literal_as_text(whereclause) - else: - self._whereclause = None - - if kwargs: - self.kwargs = self._process_deprecated_kw(kwargs) - - def get_children(self, **kwargs): - if self._whereclause is not None: - return self._whereclause, - else: - return () - - @_generative - def where(self, whereclause): - """Add the given WHERE clause to a newly returned delete construct.""" - - if self._whereclause is not None: - self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) - else: - self._whereclause = _literal_as_text(whereclause) - - def _copy_internals(self, clone=_clone): - # TODO: coverage - self._whereclause = clone(self._whereclause) - -class _IdentifiedClause(Executable, ClauseElement): - __visit_name__ = 'identified' - _execution_options = Executable._execution_options.union({'autocommit':False}) - quote = None - - def __init__(self, ident): - self.ident = ident - -class SavepointClause(_IdentifiedClause): - __visit_name__ = 'savepoint' - -class RollbackToSavepointClause(_IdentifiedClause): - __visit_name__ = 'rollback_to_savepoint' - -class ReleaseSavepointClause(_IdentifiedClause): - __visit_name__ = 'release_savepoint' - - +_BindParamClause = BindParameter +_Label = Label +_SelectBase = SelectBase +_BinaryExpression = BinaryExpression +_Cast = Cast +_Null = Null +_False = False_ +_True = True_ +_TextClause = TextClause +_UnaryExpression = UnaryExpression +_Case = Case +_Tuple = Tuple +_Over = Over +_Generative = Generative +_TypeClause = TypeClause +_Extract = Extract +_Exists = Exists +_Grouping = Grouping +_FromGrouping = FromGrouping +_ScalarSelect = ScalarSelect diff --git a/sqlalchemy/sql/functions.py b/sqlalchemy/sql/functions.py index 212f81a..08f1d32 100644 --- a/sqlalchemy/sql/functions.py +++ b/sqlalchemy/sql/functions.py @@ -1,104 +1,813 @@ -from sqlalchemy import types as sqltypes -from sqlalchemy.sql.expression import ( - ClauseList, Function, _literal_as_binds, text, _type_from_args - ) -from sqlalchemy.sql import operators -from sqlalchemy.sql.visitors import VisitableType 
+# sql/functions.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""SQL function API, factories, and built-in functions. + +""" +from . import sqltypes, schema +from .base import Executable, ColumnCollection +from .elements import ClauseList, Cast, Extract, _literal_as_binds, \ + literal_column, _type_from_args, ColumnElement, _clone,\ + Over, BindParameter, FunctionFilter, Grouping, WithinGroup +from .selectable import FromClause, Select, Alias +from . import util as sqlutil +from . import operators +from .visitors import VisitableType +from .. import util +from . import annotation + +_registry = util.defaultdict(dict) + + +def register_function(identifier, fn, package="_default"): + """Associate a callable with a particular func. name. + + This is normally called by _GenericMeta, but is also + available by itself so that a non-Function construct + can be associated with the :data:`.func` accessor (i.e. + CAST, EXTRACT). + + """ + reg = _registry[package] + reg[identifier] = fn + + +class FunctionElement(Executable, ColumnElement, FromClause): + """Base for SQL function-oriented constructs. + + .. seealso:: + + :class:`.Function` - named SQL function. + + :data:`.func` - namespace which produces registered or ad-hoc + :class:`.Function` instances. + + :class:`.GenericFunction` - allows creation of registered function + types. + + """ + + packagenames = () + + def __init__(self, *clauses, **kwargs): + """Construct a :class:`.FunctionElement`. + """ + args = [_literal_as_binds(c, self.name) for c in clauses] + self.clause_expr = ClauseList( + operator=operators.comma_op, + group_contents=True, *args).\ + self_group() + + def _execute_on_connection(self, connection, multiparams, params): + return connection._execute_function(self, multiparams, params) + + @property + def columns(self): + """The set of columns exported by this :class:`.FunctionElement`. + + Function objects currently have no result column names built in; + this method returns a single-element column collection with + an anonymously named column. + + An interim approach to providing named columns for a function + as a FROM clause is to build a :func:`.select` with the + desired columns:: + + from sqlalchemy.sql import column + + stmt = select([column('x'), column('y')]).\ + select_from(func.myfunction()) + + + """ + return ColumnCollection(self.label(None)) + + @util.memoized_property + def clauses(self): + """Return the underlying :class:`.ClauseList` which contains + the arguments for this :class:`.FunctionElement`. + + """ + return self.clause_expr.element + + def over(self, partition_by=None, order_by=None, rows=None, range_=None): + """Produce an OVER clause against this function. + + Used against aggregate or so-called "window" functions, + for database backends that support window functions. + + The expression:: + + func.row_number().over(order_by='x') + + is shorthand for:: + + from sqlalchemy import over + over(func.row_number(), order_by='x') + + See :func:`~.expression.over` for a full description. + + .. versionadded:: 0.7 + + """ + return Over( + self, + partition_by=partition_by, + order_by=order_by, + rows=rows, + range_=range_ + ) + + def within_group(self, *order_by): + """Produce a WITHIN GROUP (ORDER BY expr) clause against this function. 
+ + Used against so-called "ordered set aggregate" and "hypothetical + set aggregate" functions, including :class:`.percentile_cont`, + :class:`.rank`, :class:`.dense_rank`, etc. + + See :func:`~.expression.within_group` for a full description. + + .. versionadded:: 1.1 + + + """ + return WithinGroup(self, *order_by) + + def filter(self, *criterion): + """Produce a FILTER clause against this function. + + Used against aggregate and window functions, + for database backends that support the "FILTER" clause. + + The expression:: + + func.count(1).filter(True) + + is shorthand for:: + + from sqlalchemy import funcfilter + funcfilter(func.count(1), True) + + .. versionadded:: 1.0.0 + + .. seealso:: + + :class:`.FunctionFilter` + + :func:`.funcfilter` + + + """ + if not criterion: + return self + return FunctionFilter(self, *criterion) + + @property + def _from_objects(self): + return self.clauses._from_objects + + def get_children(self, **kwargs): + return self.clause_expr, + + def _copy_internals(self, clone=_clone, **kw): + self.clause_expr = clone(self.clause_expr, **kw) + self._reset_exported() + FunctionElement.clauses._reset(self) + + def within_group_type(self, within_group): + """For types that define their return type as based on the criteria + within a WITHIN GROUP (ORDER BY) expression, called by the + :class:`.WithinGroup` construct. + + Returns None by default, in which case the function's normal ``.type`` + is used. + + """ + + return None + + def alias(self, name=None, flat=False): + r"""Produce a :class:`.Alias` construct against this + :class:`.FunctionElement`. + + This construct wraps the function in a named alias which + is suitable for the FROM clause, in the style accepted for example + by PostgreSQL. + + e.g.:: + + from sqlalchemy.sql import column + + stmt = select([column('data_view')]).\ + select_from(SomeTable).\ + select_from(func.unnest(SomeTable.data).alias('data_view') + ) + + Would produce: + + .. sourcecode:: sql + + SELECT data_view + FROM sometable, unnest(sometable.data) AS data_view + + .. versionadded:: 0.9.8 The :meth:`.FunctionElement.alias` method + is now supported. Previously, this method's behavior was + undefined and did not behave consistently across versions. + + """ + + return Alias(self, name) + + def select(self): + """Produce a :func:`~.expression.select` construct + against this :class:`.FunctionElement`. + + This is shorthand for:: + + s = select([function_element]) + + """ + s = Select([self]) + if self._execution_options: + s = s.execution_options(**self._execution_options) + return s + + def scalar(self): + """Execute this :class:`.FunctionElement` against an embedded + 'bind' and return a scalar value. + + This first calls :meth:`~.FunctionElement.select` to + produce a SELECT construct. + + Note that :class:`.FunctionElement` can be passed to + the :meth:`.Connectable.scalar` method of :class:`.Connection` + or :class:`.Engine`. + + """ + return self.select().execute().scalar() + + def execute(self): + """Execute this :class:`.FunctionElement` against an embedded + 'bind'. + + This first calls :meth:`~.FunctionElement.select` to + produce a SELECT construct. + + Note that :class:`.FunctionElement` can be passed to + the :meth:`.Connectable.execute` method of :class:`.Connection` + or :class:`.Engine`. 
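A minimal runnable sketch of those conveniences, using an in-memory SQLite engine as a stand-in::

    from sqlalchemy import create_engine, func

    engine = create_engine("sqlite://")
    conn = engine.connect()

    # a FunctionElement handed to Connection.execute() is wrapped inside
    # a SELECT first, as described above:
    result = conn.execute(func.lower("ABC"))
    print(result.scalar())   # abc

    conn.close()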
+ + """ + return self.select().execute() + + def _bind_param(self, operator, obj, type_=None): + return BindParameter(None, obj, _compared_to_operator=operator, + _compared_to_type=self.type, unique=True, + type_=type_) + + def self_group(self, against=None): + # for the moment, we are parenthesizing all array-returning + # expressions against getitem. This may need to be made + # more portable if in the future we support other DBs + # besides postgresql. + if against is operators.getitem and \ + isinstance(self.type, sqltypes.ARRAY): + return Grouping(self) + else: + return super(FunctionElement, self).self_group(against=against) + + +class _FunctionGenerator(object): + """Generate :class:`.Function` objects based on getattr calls.""" + + def __init__(self, **opts): + self.__names = [] + self.opts = opts + + def __getattr__(self, name): + # passthru __ attributes; fixes pydoc + if name.startswith('__'): + try: + return self.__dict__[name] + except KeyError: + raise AttributeError(name) + + elif name.endswith('_'): + name = name[0:-1] + f = _FunctionGenerator(**self.opts) + f.__names = list(self.__names) + [name] + return f + + def __call__(self, *c, **kwargs): + o = self.opts.copy() + o.update(kwargs) + + tokens = len(self.__names) + + if tokens == 2: + package, fname = self.__names + elif tokens == 1: + package, fname = "_default", self.__names[0] + else: + package = None + + if package is not None: + func = _registry[package].get(fname) + if func is not None: + return func(*c, **o) + + return Function(self.__names[-1], + packagenames=self.__names[0:-1], *c, **o) + + +func = _FunctionGenerator() +"""Generate SQL function expressions. + + :data:`.func` is a special object instance which generates SQL + functions based on name-based attributes, e.g.:: + + >>> print(func.count(1)) + count(:param_1) + + The element is a column-oriented SQL element like any other, and is + used in that way:: + + >>> print(select([func.count(table.c.id)])) + SELECT count(sometable.id) FROM sometable + + Any name can be given to :data:`.func`. If the function name is unknown to + SQLAlchemy, it will be rendered exactly as is. For common SQL functions + which SQLAlchemy is aware of, the name may be interpreted as a *generic + function* which will be compiled appropriately to the target database:: + + >>> print(func.current_timestamp()) + CURRENT_TIMESTAMP + + To call functions which are present in dot-separated packages, + specify them in the same manner:: + + >>> print(func.stats.yield_curve(5, 10)) + stats.yield_curve(:yield_curve_1, :yield_curve_2) + + SQLAlchemy can be made aware of the return type of functions to enable + type-specific lexical and result-based behavior. For example, to ensure + that a string-based function returns a Unicode value and is similarly + treated as a string in expressions, specify + :class:`~sqlalchemy.types.Unicode` as the type: + + >>> print(func.my_string(u'hi', type_=Unicode) + ' ' + + ... func.my_string(u'there', type_=Unicode)) + my_string(:my_string_1) || :my_string_2 || my_string(:my_string_3) + + The object returned by a :data:`.func` call is usually an instance of + :class:`.Function`. + This object meets the "column" interface, including comparison and labeling + functions. 
The object can also be passed to the :meth:`~.Connectable.execute` + method of a :class:`.Connection` or :class:`.Engine`, where it will be + wrapped inside of a SELECT statement first:: + + print(connection.execute(func.current_timestamp()).scalar()) + + In a few exceptional cases, the :data:`.func` accessor + will redirect a name to a built-in expression such as :func:`.cast` + or :func:`.extract`, as these names have well-known meaning + but are not exactly the same as "functions" from a SQLAlchemy + perspective. + + .. versionadded:: 0.8 :data:`.func` can return non-function expression + constructs for common quasi-functional names like :func:`.cast` + and :func:`.extract`. + + Functions which are interpreted as "generic" functions know how to + calculate their return type automatically. For a listing of known generic + functions, see :ref:`generic_functions`. + + .. note:: + + The :data:`.func` construct has only limited support for calling + standalone "stored procedures", especially those with special + parameterization concerns. + + See the section :ref:`stored_procedures` for details on how to use + the DBAPI-level ``callproc()`` method for fully traditional stored + procedures. + +""" + +modifier = _FunctionGenerator(group=False) + + +class Function(FunctionElement): + """Describe a named SQL function. + + See the superclass :class:`.FunctionElement` for a description + of public methods. + + .. seealso:: + + :data:`.func` - namespace which produces registered or ad-hoc + :class:`.Function` instances. + + :class:`.GenericFunction` - allows creation of registered function + types. + + """ + + __visit_name__ = 'function' + + def __init__(self, name, *clauses, **kw): + """Construct a :class:`.Function`. + + The :data:`.func` construct is normally used to construct + new :class:`.Function` instances. + + """ + self.packagenames = kw.pop('packagenames', None) or [] + self.name = name + self._bind = kw.get('bind', None) + self.type = sqltypes.to_instance(kw.get('type_', None)) + + FunctionElement.__init__(self, *clauses, **kw) + + def _bind_param(self, operator, obj, type_=None): + return BindParameter(self.name, obj, + _compared_to_operator=operator, + _compared_to_type=self.type, + type_=type_, + unique=True) + class _GenericMeta(VisitableType): - def __call__(self, *args, **kwargs): - args = [_literal_as_binds(c) for c in args] - return type.__call__(self, *args, **kwargs) + def __init__(cls, clsname, bases, clsdict): + if annotation.Annotated not in cls.__mro__: + cls.name = name = clsdict.get('name', clsname) + cls.identifier = identifier = clsdict.get('identifier', name) + package = clsdict.pop('package', '_default') + # legacy + if '__return_type__' in clsdict: + cls.type = clsdict['__return_type__'] + register_function(identifier, cls, package) + super(_GenericMeta, cls).__init__(clsname, bases, clsdict) -class GenericFunction(Function): - __metaclass__ = _GenericMeta - def __init__(self, type_=None, args=(), **kwargs): +class GenericFunction(util.with_metaclass(_GenericMeta, Function)): + """Define a 'generic' function. + + A generic function is a pre-established :class:`.Function` + class that is instantiated automatically when called + by name from the :data:`.func` attribute. Note that + calling any name from :data:`.func` has the effect that + a new :class:`.Function` instance is created automatically, + given that name. The primary use case for defining + a :class:`.GenericFunction` class is so that a function + of a particular name may be given a fixed return type.
+ It can also include custom argument parsing schemes as well + as additional methods. + + Subclasses of :class:`.GenericFunction` are automatically + registered under the name of the class. For + example, a user-defined function ``as_utc()`` would + be available immediately:: + + from sqlalchemy.sql.functions import GenericFunction + from sqlalchemy.types import DateTime + + class as_utc(GenericFunction): + type = DateTime + + print select([func.as_utc()]) + + User-defined generic functions can be organized into + packages by specifying the "package" attribute when defining + :class:`.GenericFunction`. Third party libraries + containing many functions may want to use this in order + to avoid name conflicts with other systems. For example, + if our ``as_utc()`` function were part of a package + "time":: + + class as_utc(GenericFunction): + type = DateTime + package = "time" + + The above function would be available from :data:`.func` + using the package name ``time``:: + + print select([func.time.as_utc()]) + + A final option is to allow the function to be accessed + from one name in :data:`.func` but to render as a different name. + The ``identifier`` attribute will override the name used to + access the function as loaded from :data:`.func`, but will retain + the usage of ``name`` as the rendered name:: + + class GeoBuffer(GenericFunction): + type = Geometry + package = "geo" + name = "ST_Buffer" + identifier = "buffer" + + The above function will render as follows:: + + >>> print func.geo.buffer() + ST_Buffer() + + .. versionadded:: 0.8 :class:`.GenericFunction` now supports + automatic registration of new functions as well as package + and custom naming support. + + .. versionchanged:: 0.8 The attribute name ``type`` is used + to specify the function's return type at the class level. + Previously, the name ``__return_type__`` was used. This + name is still recognized for backwards-compatibility. + + """ + + coerce_arguments = True + + def __init__(self, *args, **kwargs): + parsed_args = kwargs.pop('_parsed_args', None) + if parsed_args is None: + parsed_args = [_literal_as_binds(c, self.name) for c in args] self.packagenames = [] - self.name = self.__class__.__name__ self._bind = kwargs.get('bind', None) self.clause_expr = ClauseList( - operator=operators.comma_op, - group_contents=True, *args).self_group() + operator=operators.comma_op, + group_contents=True, *parsed_args).self_group() self.type = sqltypes.to_instance( - type_ or getattr(self, '__return_type__', None)) + kwargs.pop("type_", None) or getattr(self, 'type', None)) + +register_function("cast", Cast) +register_function("extract", Extract) + + +class next_value(GenericFunction): + """Represent the 'next value', given a :class:`.Sequence` + as its single argument. + + Compiles into the appropriate function on each backend, + or will raise NotImplementedError if used on a backend + that does not provide support for sequences. + + """ + type = sqltypes.Integer() + name = "next_value" + + def __init__(self, seq, **kw): + assert isinstance(seq, schema.Sequence), \ + "next_value() accepts a Sequence object as input." 
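+        # retain the Sequence itself; compilation renders each backend's
+        # own "next value" SQL for it, or raises NotImplementedError on
+        # backends without sequence support (per the class docstring)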
+ self._bind = kw.get('bind', None) + self.sequence = seq + + @property + def _from_objects(self): + return [] + class AnsiFunction(GenericFunction): def __init__(self, **kwargs): GenericFunction.__init__(self, **kwargs) + class ReturnTypeFromArgs(GenericFunction): """Define a function whose return type is the same as its arguments.""" - + def __init__(self, *args, **kwargs): + args = [_literal_as_binds(c, self.name) for c in args] kwargs.setdefault('type_', _type_from_args(args)) - GenericFunction.__init__(self, args=args, **kwargs) + kwargs['_parsed_args'] = args + super(ReturnTypeFromArgs, self).__init__(*args, **kwargs) + class coalesce(ReturnTypeFromArgs): pass + class max(ReturnTypeFromArgs): pass + class min(ReturnTypeFromArgs): pass + class sum(ReturnTypeFromArgs): pass + class now(GenericFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime + class concat(GenericFunction): - __return_type__ = sqltypes.String - def __init__(self, *args, **kwargs): - GenericFunction.__init__(self, args=args, **kwargs) + type = sqltypes.String + class char_length(GenericFunction): - __return_type__ = sqltypes.Integer + type = sqltypes.Integer def __init__(self, arg, **kwargs): - GenericFunction.__init__(self, args=[arg], **kwargs) + GenericFunction.__init__(self, arg, **kwargs) + class random(GenericFunction): - def __init__(self, *args, **kwargs): - kwargs.setdefault('type_', None) - GenericFunction.__init__(self, args=args, **kwargs) + pass + class count(GenericFunction): - """The ANSI COUNT aggregate function. With no arguments, emits COUNT \*.""" + r"""The ANSI COUNT aggregate function. With no arguments, + emits COUNT \*. - __return_type__ = sqltypes.Integer + """ + type = sqltypes.Integer def __init__(self, expression=None, **kwargs): if expression is None: - expression = text('*') - GenericFunction.__init__(self, args=(expression,), **kwargs) + expression = literal_column('*') + super(count, self).__init__(expression, **kwargs) + class current_date(AnsiFunction): - __return_type__ = sqltypes.Date + type = sqltypes.Date + class current_time(AnsiFunction): - __return_type__ = sqltypes.Time + type = sqltypes.Time + class current_timestamp(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime + class current_user(AnsiFunction): - __return_type__ = sqltypes.String + type = sqltypes.String + class localtime(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime + class localtimestamp(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime + class session_user(AnsiFunction): - __return_type__ = sqltypes.String + type = sqltypes.String + class sysdate(AnsiFunction): - __return_type__ = sqltypes.DateTime + type = sqltypes.DateTime + class user(AnsiFunction): - __return_type__ = sqltypes.String + type = sqltypes.String + +class array_agg(GenericFunction): + """support for the ARRAY_AGG function. + + The ``func.array_agg(expr)`` construct returns an expression of + type :class:`.types.ARRAY`. + + e.g.:: + + stmt = select([func.array_agg(table.c.values)[2:5]]) + + .. versionadded:: 1.1 + + .. seealso:: + + :func:`.postgresql.array_agg` - PostgreSQL-specific version that + returns :class:`.postgresql.ARRAY`, which has PG-specific operators added. 
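+
+    The element type of the resulting :class:`.types.ARRAY` is derived
+    from the arguments; e.g., as a sketch, if ``table.c.values`` above
+    were an Integer column, ``func.array_agg(table.c.values).type``
+    would be ``ARRAY(Integer())``.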
+ + """ + + type = sqltypes.ARRAY + + def __init__(self, *args, **kwargs): + args = [_literal_as_binds(c) for c in args] + kwargs.setdefault('type_', self.type(_type_from_args(args))) + kwargs['_parsed_args'] = args + super(array_agg, self).__init__(*args, **kwargs) + + +class OrderedSetAgg(GenericFunction): + """Define a function where the return type is based on the sort + expression type as defined by the expression passed to the + :meth:`.FunctionElement.within_group` method.""" + + array_for_multi_clause = False + + def within_group_type(self, within_group): + func_clauses = self.clause_expr.element + order_by = sqlutil.unwrap_order_by(within_group.order_by) + if self.array_for_multi_clause and len(func_clauses.clauses) > 1: + return sqltypes.ARRAY(order_by[0].type) + else: + return order_by[0].type + + +class mode(OrderedSetAgg): + """implement the ``mode`` ordered-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is the same as the sort expression. + + .. versionadded:: 1.1 + + """ + + +class percentile_cont(OrderedSetAgg): + """implement the ``percentile_cont`` ordered-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is the same as the sort expression, + or if the arguments are an array, an :class:`.types.ARRAY` of the sort + expression's type. + + .. versionadded:: 1.1 + + """ + + array_for_multi_clause = True + + +class percentile_disc(OrderedSetAgg): + """implement the ``percentile_disc`` ordered-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is the same as the sort expression, + or if the arguments are an array, an :class:`.types.ARRAY` of the sort + expression's type. + + .. versionadded:: 1.1 + + """ + + array_for_multi_clause = True + + +class rank(GenericFunction): + """Implement the ``rank`` hypothetical-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is :class:`.Integer`. + + .. versionadded:: 1.1 + + """ + type = sqltypes.Integer() + + +class dense_rank(GenericFunction): + """Implement the ``dense_rank`` hypothetical-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is :class:`.Integer`. + + .. versionadded:: 1.1 + + """ + type = sqltypes.Integer() + + +class percent_rank(GenericFunction): + """Implement the ``percent_rank`` hypothetical-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is :class:`.Numeric`. + + .. versionadded:: 1.1 + + """ + type = sqltypes.Numeric() + + +class cume_dist(GenericFunction): + """Implement the ``cume_dist`` hypothetical-set aggregate function. + + This function must be used with the :meth:`.FunctionElement.within_group` + modifier to supply a sort expression to operate upon. + + The return type of this function is :class:`.Numeric`. + + .. 
versionadded:: 1.1 + + """ + type = sqltypes.Numeric() diff --git a/sqlalchemy/sql/operators.py b/sqlalchemy/sql/operators.py index 6f70b17..d883392 100644 --- a/sqlalchemy/sql/operators.py +++ b/sqlalchemy/sql/operators.py @@ -1,135 +1,1014 @@ +# sql/operators.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """Defines operators used in SQL expressions.""" -from operator import ( - and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq, neg - ) - -# Py2K -from operator import (div,) -# end Py2K +from .. import util -from sqlalchemy.util import symbol +from operator import ( + and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq, neg, + getitem, lshift, rshift, contains +) + +if util.py2k: + from operator import div +else: + div = truediv + + +class Operators(object): + """Base of comparison and logical operators. + + Implements base methods + :meth:`~sqlalchemy.sql.operators.Operators.operate` and + :meth:`~sqlalchemy.sql.operators.Operators.reverse_operate`, as well as + :meth:`~sqlalchemy.sql.operators.Operators.__and__`, + :meth:`~sqlalchemy.sql.operators.Operators.__or__`, + :meth:`~sqlalchemy.sql.operators.Operators.__invert__`. + + It is usually used via its most common subclass + :class:`.ColumnOperators`. + + """ + __slots__ = () + + def __and__(self, other): + """Implement the ``&`` operator. + + When used with SQL expressions, results in an + AND operation, equivalent to + :func:`~.expression.and_`, that is:: + + a & b + + is equivalent to:: + + from sqlalchemy import and_ + and_(a, b) + + Care should be taken when using ``&`` regarding + operator precedence; the ``&`` operator has the highest precedence. + The operands should be enclosed in parentheses if they contain + further sub expressions:: + + (a == 2) & (b == 4) + + """ + return self.operate(and_, other) + + def __or__(self, other): + """Implement the ``|`` operator. + + When used with SQL expressions, results in an + OR operation, equivalent to + :func:`~.expression.or_`, that is:: + + a | b + + is equivalent to:: + + from sqlalchemy import or_ + or_(a, b) + + Care should be taken when using ``|`` regarding + operator precedence; the ``|`` operator has the highest precedence. + The operands should be enclosed in parentheses if they contain + further sub expressions:: + + (a == 2) | (b == 4) + + """ + return self.operate(or_, other) + + def __invert__(self): + """Implement the ``~`` operator. + + When used with SQL expressions, results in a + NOT operation, equivalent to + :func:`~.expression.not_`, that is:: + + ~a + + is equivalent to:: + + from sqlalchemy import not_ + not_(a) + + """ + return self.operate(inv) + + def op(self, opstring, precedence=0, is_comparison=False): + """produce a generic operator function. + + e.g.:: + + somecolumn.op("*")(5) + + produces:: + + somecolumn * 5 + + This function can also be used to make bitwise operators explicit. For + example:: + + somecolumn.op('&')(0xff) + + is a bitwise AND of the value in ``somecolumn``. + + :param opstring: a string which will be output as the infix operator + between this element and the expression passed to the + generated function. + + :param precedence: precedence to apply to the operator, when + parenthesizing expressions.
A lower number will cause the expression + to be parenthesized when applied against another operator with + higher precedence. The default value of ``0`` is lower than all + operators except for the comma (``,``) and ``AS`` operators. + A value of 100 will be higher than or equal to all operators, and -100 + will be lower than or equal to all operators. + + .. versionadded:: 0.8 - added the 'precedence' argument. + + :param is_comparison: if True, the operator will be considered as a + "comparison" operator, that is, one which evaluates to a boolean + true/false value, like ``==``, ``>``, etc. This flag should be set + so that ORM relationships can establish that the operator is a + comparison operator when used in a custom join condition. + + .. versionadded:: 0.9.2 - added the + :paramref:`.Operators.op.is_comparison` flag. + + .. seealso:: + + :ref:`types_operators` + + :ref:`relationship_custom_operator` + + """ + operator = custom_op(opstring, precedence, is_comparison) + + def against(other): + return operator(self, other) + return against + + def operate(self, op, *other, **kwargs): + r"""Operate on an argument. + + This is the lowest level of operation; it raises + :class:`NotImplementedError` by default. + + Overriding this on a subclass can allow common + behavior to be applied to all operations. + For example, overriding :class:`.ColumnOperators` + to apply ``func.lower()`` to the left and right + side:: + + class MyComparator(ColumnOperators): + def operate(self, op, other): + return op(func.lower(self), func.lower(other)) + + :param op: Operator callable. + :param \*other: the 'other' side of the operation. Will + be a single scalar for most operations. + :param \**kwargs: modifiers. These may be passed by special + operators such as :meth:`ColumnOperators.contains`. + + + """ + raise NotImplementedError(str(op)) + + def reverse_operate(self, op, other, **kwargs): + """Reverse operate on an argument. + + Usage is the same as :meth:`operate`. + + """ + raise NotImplementedError(str(op)) + + +class custom_op(object): + """Represent a 'custom' operator. + + :class:`.custom_op` is normally instantiated when the + :meth:`.ColumnOperators.op` method is used to create a + custom operator callable. The class can also be used directly + when programmatically constructing expressions. E.g. + to represent the "factorial" operation:: + + from sqlalchemy.sql import UnaryExpression + from sqlalchemy.sql import operators + from sqlalchemy import Numeric + + unary = UnaryExpression(table.c.somecolumn, + modifier=operators.custom_op("!"), + type_=Numeric) + + """ + __name__ = 'custom_op' + + def __init__( + self, opstring, precedence=0, is_comparison=False, + natural_self_precedent=False, eager_grouping=False): + self.opstring = opstring + self.precedence = precedence + self.is_comparison = is_comparison + self.natural_self_precedent = natural_self_precedent + self.eager_grouping = eager_grouping + + def __eq__(self, other): + return isinstance(other, custom_op) and \ + other.opstring == self.opstring + + def __hash__(self): + return id(self) + + def __call__(self, left, right, **kw): + return left.operate(self, right, **kw) + + +class ColumnOperators(Operators): + """Defines boolean, comparison, and other operators for + :class:`.ColumnElement` expressions.
+ + By default, all methods call down to + :meth:`.operate` or :meth:`.reverse_operate`, + passing in the appropriate operator function from the + Python builtin ``operator`` module or + a SQLAlchemy-specific operator function from + :mod:`sqlalchemy.sql.operators`. For example, + the ``__eq__`` function:: + + def __eq__(self, other): + return self.operate(operators.eq, other) + + Where ``operators.eq`` is essentially:: + + def eq(a, b): + return a == b + + The core column expression unit :class:`.ColumnElement` + overrides :meth:`.Operators.operate` and others + to return further :class:`.ColumnElement` constructs, + so that the ``==`` operation above is replaced by a clause + construct. + + See also: + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + :class:`.ColumnOperators` + + :class:`.PropComparator` + + """ + + __slots__ = () + + timetuple = None + """Hack, allows datetime objects to be compared on the LHS.""" + + def __lt__(self, other): + """Implement the ``<`` operator. + + In a column context, produces the clause ``a < b``. + + """ + return self.operate(lt, other) + + def __le__(self, other): + """Implement the ``<=`` operator. + + In a column context, produces the clause ``a <= b``. + + """ + return self.operate(le, other) + + __hash__ = Operators.__hash__ + + def __eq__(self, other): + """Implement the ``==`` operator. + + In a column context, produces the clause ``a = b``. + If the target is ``None``, produces ``a IS NULL``. + + """ + return self.operate(eq, other) + + def __ne__(self, other): + """Implement the ``!=`` operator. + + In a column context, produces the clause ``a != b``. + If the target is ``None``, produces ``a IS NOT NULL``. + + """ + return self.operate(ne, other) + + def is_distinct_from(self, other): + """Implement the ``IS DISTINCT FROM`` operator. + + Renders "a IS DISTINCT FROM b" on most platforms; + on some such as SQLite may render "a IS NOT b". + + .. versionadded:: 1.1 + + """ + return self.operate(is_distinct_from, other) + + def isnot_distinct_from(self, other): + """Implement the ``IS NOT DISTINCT FROM`` operator. + + Renders "a IS NOT DISTINCT FROM b" on most platforms; + on some such as SQLite may render "a IS b". + + .. versionadded:: 1.1 + + """ + return self.operate(isnot_distinct_from, other) + + def __gt__(self, other): + """Implement the ``>`` operator. + + In a column context, produces the clause ``a > b``. + + """ + return self.operate(gt, other) + + def __ge__(self, other): + """Implement the ``>=`` operator. + + In a column context, produces the clause ``a >= b``. + + """ + return self.operate(ge, other) + + def __neg__(self): + """Implement the ``-`` operator. + + In a column context, produces the clause ``-a``. + + """ + return self.operate(neg) + + def __contains__(self, other): + return self.operate(contains, other) + + def __getitem__(self, index): + """Implement the [] operator. + + This can be used by some database-specific types + such as PostgreSQL ARRAY and HSTORE. + + """ + return self.operate(getitem, index) + + def __lshift__(self, other): + """implement the << operator. + + Not used by SQLAlchemy core, this is provided + for custom operator systems which want to use + << as an extension point. + """ + return self.operate(lshift, other) + + def __rshift__(self, other): + """implement the >> operator. + + Not used by SQLAlchemy core, this is provided + for custom operator systems which want to use + >> as an extension point.
+ """ + return self.operate(rshift, other) + + def concat(self, other): + """Implement the 'concat' operator. + + In a column context, produces the clause ``a || b``, + or uses the ``concat()`` operator on MySQL. + + """ + return self.operate(concat_op, other) + + def like(self, other, escape=None): + r"""Implement the ``like`` operator. + + In a column context, produces the expression:: + + a LIKE other + + E.g.:: + + stmt = select([sometable]).\ + where(sometable.c.column.like("%foobar%")) + + :param other: expression to be compared + :param escape: optional escape character, renders the ``ESCAPE`` + keyword, e.g.:: + + somecolumn.like("foo/%bar", escape="/") + + .. seealso:: + + :meth:`.ColumnOperators.ilike` + + """ + return self.operate(like_op, other, escape=escape) + + def ilike(self, other, escape=None): + r"""Implement the ``ilike`` operator, e.g. case insensitive LIKE. + + In a column context, produces an expression either of the form:: + + lower(a) LIKE lower(other) + + Or on backends that support the ILIKE operator:: + + a ILIKE other + + E.g.:: + + stmt = select([sometable]).\ + where(sometable.c.column.ilike("%foobar%")) + + :param other: expression to be compared + :param escape: optional escape character, renders the ``ESCAPE`` + keyword, e.g.:: + + somecolumn.ilike("foo/%bar", escape="/") + + .. seealso:: + + :meth:`.ColumnOperators.like` + + """ + return self.operate(ilike_op, other, escape=escape) + + def in_(self, other): + """Implement the ``in`` operator. + + In a column context, produces the clause ``a IN other``. + "other" may be a tuple/list of column expressions, + or a :func:`~.expression.select` construct. + + """ + return self.operate(in_op, other) + + def notin_(self, other): + """implement the ``NOT IN`` operator. + + This is equivalent to using negation with + :meth:`.ColumnOperators.in_`, i.e. ``~x.in_(y)``. + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`.ColumnOperators.in_` + + """ + return self.operate(notin_op, other) + + def notlike(self, other, escape=None): + """implement the ``NOT LIKE`` operator. + + This is equivalent to using negation with + :meth:`.ColumnOperators.like`, i.e. ``~x.like(y)``. + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`.ColumnOperators.like` + + """ + return self.operate(notlike_op, other, escape=escape) + + def notilike(self, other, escape=None): + """implement the ``NOT ILIKE`` operator. + + This is equivalent to using negation with + :meth:`.ColumnOperators.ilike`, i.e. ``~x.ilike(y)``. + + .. versionadded:: 0.8 + + .. seealso:: + + :meth:`.ColumnOperators.ilike` + + """ + return self.operate(notilike_op, other, escape=escape) + + def is_(self, other): + """Implement the ``IS`` operator. + + Normally, ``IS`` is generated automatically when comparing to a + value of ``None``, which resolves to ``NULL``. However, explicit + usage of ``IS`` may be desirable if comparing to boolean values + on certain platforms. + + .. versionadded:: 0.7.9 + + .. seealso:: :meth:`.ColumnOperators.isnot` + + """ + return self.operate(is_, other) + + def isnot(self, other): + """Implement the ``IS NOT`` operator. + + Normally, ``IS NOT`` is generated automatically when comparing to a + value of ``None``, which resolves to ``NULL``. However, explicit + usage of ``IS NOT`` may be desirable if comparing to boolean values + on certain platforms. + + .. versionadded:: 0.7.9 + + .. 
seealso:: :meth:`.ColumnOperators.is_` + + """ + return self.operate(isnot, other) + + def startswith(self, other, **kwargs): + """Implement the ``startswith`` operator. + + In a column context, produces the clause ``LIKE '<other>%'`` + + """ + return self.operate(startswith_op, other, **kwargs) + + def endswith(self, other, **kwargs): + """Implement the 'endswith' operator. + + In a column context, produces the clause ``LIKE '%<other>'`` + + """ + return self.operate(endswith_op, other, **kwargs) + + def contains(self, other, **kwargs): + """Implement the 'contains' operator. + + In a column context, produces the clause ``LIKE '%<other>%'`` + + """ + return self.operate(contains_op, other, **kwargs) + + def match(self, other, **kwargs): + """Implements a database-specific 'match' operator. + + :meth:`~.ColumnOperators.match` attempts to resolve to + a MATCH-like function or operator provided by the backend. + Examples include: + + * PostgreSQL - renders ``x @@ to_tsquery(y)`` + * MySQL - renders ``MATCH (x) AGAINST (y IN BOOLEAN MODE)`` + * Oracle - renders ``CONTAINS(x, y)`` + * other backends may provide special implementations. + * Backends without any special implementation will emit + the operator as "MATCH". This is compatible with SQLite, for + example. + + """ + return self.operate(match_op, other, **kwargs) + + def desc(self): + """Produce a :func:`~.expression.desc` clause against the + parent object.""" + return self.operate(desc_op) + + def asc(self): + """Produce a :func:`~.expression.asc` clause against the + parent object.""" + return self.operate(asc_op) + + def nullsfirst(self): + """Produce a :func:`~.expression.nullsfirst` clause against the + parent object.""" + return self.operate(nullsfirst_op) + + def nullslast(self): + """Produce a :func:`~.expression.nullslast` clause against the + parent object.""" + return self.operate(nullslast_op) + + def collate(self, collation): + """Produce a :func:`~.expression.collate` clause against + the parent object, given the collation string.""" + return self.operate(collate, collation) + + def __radd__(self, other): + """Implement the ``+`` operator in reverse. + + See :meth:`.ColumnOperators.__add__`. + + """ + return self.reverse_operate(add, other) + + def __rsub__(self, other): + """Implement the ``-`` operator in reverse. + + See :meth:`.ColumnOperators.__sub__`. + + """ + return self.reverse_operate(sub, other) + + def __rmul__(self, other): + """Implement the ``*`` operator in reverse. + + See :meth:`.ColumnOperators.__mul__`. + + """ + return self.reverse_operate(mul, other) + + def __rdiv__(self, other): + """Implement the ``/`` operator in reverse. + + See :meth:`.ColumnOperators.__div__`. + + """ + return self.reverse_operate(div, other) + + def __rmod__(self, other): + """Implement the ``%`` operator in reverse. + + See :meth:`.ColumnOperators.__mod__`. + + """ + return self.reverse_operate(mod, other) + + def between(self, cleft, cright, symmetric=False): + """Produce a :func:`~.expression.between` clause against + the parent object, given the lower and upper range. + + """ + return self.operate(between_op, cleft, cright, symmetric=symmetric) + + def distinct(self): + """Produce a :func:`~.expression.distinct` clause against the + parent object. + + """ + return self.operate(distinct_op) + + def any_(self): + """Produce a :func:`~.expression.any_` clause against the + parent object. + + .. versionadded:: 1.1 + + """ + return self.operate(any_op) + + def all_(self): + """Produce a :func:`~.expression.all_` clause against the + parent object.
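+
+        E.g., a sketch using the standalone :func:`~.expression.all_`
+        form against a hypothetical ARRAY column ``t.c.data``::
+
+            from sqlalchemy import all_, select
+
+            stmt = select([t]).where(5 == all_(t.c.data))
+            # roughly: SELECT ... WHERE :param_1 = ALL (t.data)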
+ + .. versionadded:: 1.1 + + """ + return self.operate(all_op) + + def __add__(self, other): + """Implement the ``+`` operator. + + In a column context, produces the clause ``a + b`` + if the parent object has non-string affinity. + If the parent object has a string affinity, + produces the concatenation operator, ``a || b`` - + see :meth:`.ColumnOperators.concat`. + + """ + return self.operate(add, other) + + def __sub__(self, other): + """Implement the ``-`` operator. + + In a column context, produces the clause ``a - b``. + + """ + return self.operate(sub, other) + + def __mul__(self, other): + """Implement the ``*`` operator. + + In a column context, produces the clause ``a * b``. + + """ + return self.operate(mul, other) + + def __div__(self, other): + """Implement the ``/`` operator. + + In a column context, produces the clause ``a / b``. + + """ + return self.operate(div, other) + + def __mod__(self, other): + """Implement the ``%`` operator. + + In a column context, produces the clause ``a % b``. + + """ + return self.operate(mod, other) + + def __truediv__(self, other): + """Implement the ``/`` operator. + + In a column context, produces the clause ``a / b``. + + """ + return self.operate(truediv, other) + + def __rtruediv__(self, other): + """Implement the ``/`` operator in reverse. + + See :meth:`.ColumnOperators.__truediv__`. + + """ + return self.reverse_operate(truediv, other) def from_(): raise NotImplementedError() + def as_(): raise NotImplementedError() + def exists(): raise NotImplementedError() -def is_(): + +def istrue(a): raise NotImplementedError() -def isnot(): + +def isfalse(a): raise NotImplementedError() -def collate(): - raise NotImplementedError() + +def is_distinct_from(a, b): + return a.is_distinct_from(b) + + +def isnot_distinct_from(a, b): + return a.isnot_distinct_from(b) + + +def is_(a, b): + return a.is_(b) + + +def isnot(a, b): + return a.isnot(b) + + +def collate(a, b): + return a.collate(b) + def op(a, opstring, b): return a.op(opstring)(b) + def like_op(a, b, escape=None): return a.like(b, escape=escape) + def notlike_op(a, b, escape=None): - raise NotImplementedError() + return a.notlike(b, escape=escape) + def ilike_op(a, b, escape=None): return a.ilike(b, escape=escape) -def notilike_op(a, b, escape=None): - raise NotImplementedError() -def between_op(a, b, c): - return a.between(b, c) +def notilike_op(a, b, escape=None): + return a.notilike(b, escape=escape) + + +def between_op(a, b, c, symmetric=False): + return a.between(b, c, symmetric=symmetric) + + +def notbetween_op(a, b, c, symmetric=False): + return a.notbetween(b, c, symmetric=symmetric) + def in_op(a, b): return a.in_(b) + def notin_op(a, b): - raise NotImplementedError() + return a.notin_(b) + def distinct_op(a): return a.distinct() + +def any_op(a): + return a.any_() + + +def all_op(a): + return a.all_() + + def startswith_op(a, b, escape=None): return a.startswith(b, escape=escape) + +def notstartswith_op(a, b, escape=None): + return ~a.startswith(b, escape=escape) + + def endswith_op(a, b, escape=None): return a.endswith(b, escape=escape) + +def notendswith_op(a, b, escape=None): + return ~a.endswith(b, escape=escape) + + def contains_op(a, b, escape=None): return a.contains(b, escape=escape) -def match_op(a, b): - return a.match(b) + +def notcontains_op(a, b, escape=None): + return ~a.contains(b, escape=escape) + + +def match_op(a, b, **kw): + return a.match(b, **kw) + + +def notmatch_op(a, b, **kw): + return a.notmatch(b, **kw) + def comma_op(a, b): raise NotImplementedError()
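+
+# Illustrative note: most of the module-level *_op helpers in this section
+# simply dispatch to the same-named method on the left operand, e.g.
+# concat_op(a, b) calls a.concat(b), so ColumnElement subclasses and
+# custom comparators control how each operator ultimately renders.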
+ def concat_op(a, b): return a.concat(b) + def desc_op(a): return a.desc() + def asc_op(a): return a.asc() + +def nullsfirst_op(a): + return a.nullsfirst() + + +def nullslast_op(a): + return a.nullslast() + + +def json_getitem_op(a, b): + raise NotImplementedError() + + +def json_path_getitem_op(a, b): + raise NotImplementedError() + + _commutative = set([eq, ne, add, mul]) + +_comparison = set([eq, ne, lt, gt, ge, le, between_op, like_op]) + + +def is_comparison(op): + return op in _comparison or \ + isinstance(op, custom_op) and op.is_comparison + + def is_commutative(op): return op in _commutative -_smallest = symbol('_smallest') -_largest = symbol('_largest') + +def is_ordering_modifier(op): + return op in (asc_op, desc_op, + nullsfirst_op, nullslast_op) + + +def is_natural_self_precedent(op): + return op in _natural_self_precedent or \ + isinstance(op, custom_op) and op.natural_self_precedent + +_mirror = { + gt: lt, + ge: le, + lt: gt, + le: ge +} + + +def mirror(op): + """rotate a comparison operator 180 degrees. + + Note this is not the same as negation. + + """ + return _mirror.get(op, op) + + +_associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne]) + +_natural_self_precedent = _associative.union([ + getitem, json_getitem_op, json_path_getitem_op]) +"""Operators where if we have (a op b) op c, we don't want to +parenthesize (a op b). + +""" + + +_asbool = util.symbol('_asbool', canonical=-10) +_smallest = util.symbol('_smallest', canonical=-100) +_largest = util.symbol('_largest', canonical=100) _PRECEDENCE = { from_: 15, - mul: 7, - truediv: 7, - # Py2K - div: 7, - # end Py2K - mod: 7, - neg: 7, - add: 6, - sub: 6, + any_op: 15, + all_op: 15, + getitem: 15, + json_getitem_op: 15, + json_path_getitem_op: 15, + + mul: 8, + truediv: 8, + div: 8, + mod: 8, + neg: 8, + add: 7, + sub: 7, + concat_op: 6, match_op: 6, - ilike_op: 5, - notilike_op: 5, - like_op: 5, - notlike_op: 5, - in_op: 5, - notin_op: 5, - is_: 5, - isnot: 5, + notmatch_op: 6, + + ilike_op: 6, + notilike_op: 6, + like_op: 6, + notlike_op: 6, + in_op: 6, + notin_op: 6, + + is_: 6, + isnot: 6, + eq: 5, ne: 5, + is_distinct_from: 5, + isnot_distinct_from: 5, gt: 5, lt: 5, ge: 5, le: 5, + between_op: 5, + notbetween_op: 5, distinct_op: 5, inv: 5, + istrue: 5, + isfalse: 5, and_: 3, or_: 2, comma_op: -1, - collate: 7, + + desc_op: 3, + asc_op: 3, + collate: 4, + as_: -1, exists: 0, - _smallest: -1000, - _largest: 1000 + + _asbool: -10, + _smallest: _smallest, + _largest: _largest } + def is_precedent(operator, against): - return (_PRECEDENCE.get(operator, _PRECEDENCE[_smallest]) <= - _PRECEDENCE.get(against, _PRECEDENCE[_largest])) + if operator is against and is_natural_self_precedent(operator): + return False + else: + return (_PRECEDENCE.get(operator, + getattr(operator, 'precedence', _smallest)) <= + _PRECEDENCE.get(against, + getattr(against, 'precedence', _largest))) diff --git a/sqlalchemy/sql/util.py b/sqlalchemy/sql/util.py index d5575e0..281d5f6 100644 --- a/sqlalchemy/sql/util.py +++ b/sqlalchemy/sql/util.py @@ -1,44 +1,54 @@ -from sqlalchemy import exc, schema, topological, util, sql, types as sqltypes -from sqlalchemy.sql import expression, operators, visitors +# sql/util.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""High level utilities which build upon other modules here. + +""" + +from .. 
import exc, util +from .base import _from_objects, ColumnSet +from . import operators, visitors from itertools import chain +from collections import deque -"""Utility functions that build upon SQL and Schema constructs.""" +from .elements import BindParameter, ColumnClause, ColumnElement, \ + Null, UnaryExpression, literal_column, Label, _label_reference, \ + _textual_label_reference +from .selectable import ScalarSelect, Join, FromClause, FromGrouping +from .schema import Column -def sort_tables(tables): - """sort a collection of Table objects in order of their foreign-key dependency.""" - - tables = list(tables) - tuples = [] - def visit_foreign_key(fkey): - if fkey.use_alter: - return - parent_table = fkey.column.table - if parent_table in tables: - child_table = fkey.parent.table - tuples.append( ( parent_table, child_table ) ) +join_condition = util.langhelpers.public_factory( + Join._join_condition, + ".sql.util.join_condition") + +# names that are still being imported from the outside +from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate +from .elements import _find_columns +from .ddl import sort_tables - for table in tables: - visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key}) - return topological.sort(tuples, tables) def find_join_source(clauses, join_to): - """Given a list of FROM clauses and a selectable, - return the first index and element from the list of - clauses which can be joined against the selectable. returns + """Given a list of FROM clauses and a selectable, + return the first index and element from the list of + clauses which can be joined against the selectable. Returns None, None if no match is found. - + e.g.:: - + clause1 = table1.join(table2) clause2 = table4.join(table5) - + join_to = table2.join(table3) - + find_join_source([clause1, clause2], join_to) == clause1 - + """ - - selectables = list(expression._from_objects(join_to)) + + selectables = list(_from_objects(join_to)) for i, f in enumerate(clauses): for s in selectables: if f.is_derived_from(s): @@ -46,29 +56,88 @@ def find_join_source(clauses, join_to): else: return None, None - - -def find_tables(clause, check_columns=False, - include_aliases=False, include_joins=False, + +def visit_binary_product(fn, expr): + """Produce a traversal of the given expression, delivering + column comparisons to the given function. + + The function is of the form:: + + def my_fn(binary, left, right) + + For each binary expression located which has a + comparison operator, the product of "left" and + "right" will be delivered to that function, + in terms of that binary. + + Hence an expression like:: + + and_( + (a + b) == q + func.sum(e + f), + j == r + ) + + would have the traversal:: + + a <eq> q + a <eq> e + a <eq> f + b <eq> q + b <eq> e + b <eq> f + j <eq> r + + That is, every combination of "left" and + "right" that doesn't further contain + a binary comparison is passed as pairs.
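+
+    A runnable sketch (table and column names here are hypothetical)::
+
+        from sqlalchemy import table, column, and_
+
+        t = table('t', column('a'), column('b'), column('c'))
+        pairs = []
+        visit_binary_product(
+            lambda binary, l, r: pairs.append((l.name, r.name)),
+            and_(t.c.a == t.c.b, t.c.b == t.c.c))
+        # pairs == [('a', 'b'), ('b', 'c')]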
+ + """ + stack = [] + + def visit(element): + if isinstance(element, ScalarSelect): + # we don't want to dig into correlated subqueries, + # those are just column elements by themselves + yield element + elif element.__visit_name__ == 'binary' and \ + operators.is_comparison(element.operator): + stack.insert(0, element) + for l in visit(element.left): + for r in visit(element.right): + fn(stack[0], l, r) + stack.pop(0) + for elem in element.get_children(): + visit(elem) + else: + if isinstance(element, ColumnClause): + yield element + for elem in element.get_children(): + for e in visit(elem): + yield e + list(visit(expr)) + + +def find_tables(clause, check_columns=False, + include_aliases=False, include_joins=False, include_selects=False, include_crud=False): """locate Table objects within the given expression.""" - + tables = [] _visitors = {} - + if include_selects: _visitors['select'] = _visitors['compound_select'] = tables.append - + if include_joins: _visitors['join'] = tables.append - + if include_aliases: - _visitors['alias'] = tables.append - + _visitors['alias'] = tables.append + if include_crud: _visitors['insert'] = _visitors['update'] = \ - _visitors['delete'] = lambda ent: tables.append(ent.table) - + _visitors['delete'] = lambda ent: tables.append(ent.table) + if check_columns: def visit_column(column): tables.append(column.table) @@ -76,271 +145,308 @@ def find_tables(clause, check_columns=False, _visitors['table'] = tables.append - visitors.traverse(clause, {'column_collections':False}, _visitors) + visitors.traverse(clause, {'column_collections': False}, _visitors) return tables -def find_columns(clause): - """locate Column objects within the given expression.""" - + +def unwrap_order_by(clause): + """Break up an 'order by' expression into individual column-expressions, + without DESC/ASC/NULLS FIRST/NULLS LAST""" + cols = util.column_set() - visitors.traverse(clause, {}, {'column':cols.add}) - return cols + result = [] + stack = deque([clause]) + while stack: + t = stack.popleft() + if isinstance(t, ColumnElement) and \ + ( + not isinstance(t, UnaryExpression) or + not operators.is_ordering_modifier(t.modifier) + ): + if isinstance(t, _label_reference): + t = t.element + if isinstance(t, (_textual_label_reference)): + continue + if t not in cols: + cols.add(t) + result.append(t) + else: + for c in t.get_children(): + stack.append(c) + return result + + +def unwrap_label_reference(element): + def replace(elem): + if isinstance(elem, (_label_reference, _textual_label_reference)): + return elem.element + + return visitors.replacement_traverse( + element, {}, replace + ) + + +def expand_column_list_from_order_by(collist, order_by): + """Given the columns clause and ORDER BY of a selectable, + return a list of column expressions that can be added to the collist + corresponding to the ORDER BY, without repeating those already + in the collist. + + """ + cols_already_present = set([ + col.element if col._order_by_label_element is not None + else col for col in collist + ]) + + return [ + col for col in + chain(*[ + unwrap_order_by(o) + for o in order_by + ]) + if col not in cols_already_present + ] + + +def clause_is_present(clause, search): + """Given a target clause and a second to search within, return True + if the target is plainly present in the search without any + subqueries or aliases involved. + + Basically descends through Joins. 
+ + """ + + for elem in surface_selectables(search): + if clause == elem: # use == here so that Annotated's compare + return True + else: + return False + + +def surface_selectables(clause): + stack = [clause] + while stack: + elem = stack.pop() + yield elem + if isinstance(elem, Join): + stack.extend((elem.left, elem.right)) + elif isinstance(elem, FromGrouping): + stack.append(elem.element) + + +def surface_column_elements(clause): + """traverse and yield only outer-exposed column elements, such as would + be addressable in the WHERE clause of a SELECT if this element were + in the columns clause.""" + + stack = deque([clause]) + while stack: + elem = stack.popleft() + yield elem + for sub in elem.get_children(): + if isinstance(sub, FromGrouping): + continue + stack.append(sub) + + +def selectables_overlap(left, right): + """Return True if left/right have some overlapping selectable""" + + return bool( + set(surface_selectables(left)).intersection( + surface_selectables(right) + ) + ) + + +def bind_values(clause): + """Return an ordered list of "bound" values in the given clause. + + E.g.:: + + >>> expr = and_( + ... table.c.foo==5, table.c.foo==7 + ... ) + >>> bind_values(expr) + [5, 7] + """ + + v = [] + + def visit_bindparam(bind): + v.append(bind.effective_value) + + visitors.traverse(clause, {}, {'bindparam': visit_bindparam}) + return v + def _quote_ddl_expr(element): - if isinstance(element, basestring): + if isinstance(element, util.string_types): element = element.replace("'", "''") return "'%s'" % element else: return repr(element) - -def expression_as_ddl(clause): - """Given a SQL expression, convert for usage in DDL, such as - CREATE INDEX and CHECK CONSTRAINT. - - Converts bind params into quoted literals, column identifiers - into detached column constructs so that the parent table - identifier is not included. - + + +class _repr_base(object): + _LIST = 0 + _TUPLE = 1 + _DICT = 2 + + __slots__ = 'max_chars', + + def trunc(self, value): + rep = repr(value) + lenrep = len(rep) + if lenrep > self.max_chars: + segment_length = self.max_chars // 2 + rep = ( + rep[0:segment_length] + + (" ... (%d characters truncated) ... " + % (lenrep - self.max_chars)) + + rep[-segment_length:] + ) + return rep + + +class _repr_row(_repr_base): + """Provide a string view of a row.""" + + __slots__ = 'row', + + def __init__(self, row, max_chars=300): + self.row = row + self.max_chars = max_chars + + def __repr__(self): + trunc = self.trunc + return "(%s%s)" % ( + ", ".join(trunc(value) for value in self.row), + "," if len(self.row) == 1 else "" + ) + + +class _repr_params(_repr_base): + """Provide a string view of bound parameters. + + Truncates display to a given numnber of 'multi' parameter sets, + as well as long values to a given number of characters. 
+ """ - def repl(element): - if isinstance(element, expression._BindParamClause): - return expression.literal_column(_quote_ddl_expr(element.value)) - elif isinstance(element, expression.ColumnClause) and \ - element.table is not None: - return expression.column(element.name) + + __slots__ = 'params', 'batches', + + def __init__(self, params, batches, max_chars=300): + self.params = params + self.batches = batches + self.max_chars = max_chars + + def __repr__(self): + if isinstance(self.params, list): + typ = self._LIST + ismulti = self.params and isinstance( + self.params[0], (list, dict, tuple)) + elif isinstance(self.params, tuple): + typ = self._TUPLE + ismulti = self.params and isinstance( + self.params[0], (list, dict, tuple)) + elif isinstance(self.params, dict): + typ = self._DICT + ismulti = False else: - return None - - return visitors.replacement_traverse(clause, {}, repl) - + return self.trunc(self.params) + + if ismulti and len(self.params) > self.batches: + msg = " ... displaying %i of %i total bound parameter sets ... " + return ' '.join(( + self._repr_multi(self.params[:self.batches - 2], typ)[0:-1], + msg % (self.batches, len(self.params)), + self._repr_multi(self.params[-2:], typ)[1:] + )) + elif ismulti: + return self._repr_multi(self.params, typ) + else: + return self._repr_params(self.params, typ) + + def _repr_multi(self, multi_params, typ): + if multi_params: + if isinstance(multi_params[0], list): + elem_type = self._LIST + elif isinstance(multi_params[0], tuple): + elem_type = self._TUPLE + elif isinstance(multi_params[0], dict): + elem_type = self._DICT + else: + assert False, \ + "Unknown parameter type %s" % (type(multi_params[0])) + + elements = ", ".join( + self._repr_params(params, elem_type) + for params in multi_params) + else: + elements = "" + + if typ == self._LIST: + return "[%s]" % elements + else: + return "(%s)" % elements + + def _repr_params(self, params, typ): + trunc = self.trunc + if typ is self._DICT: + return "{%s}" % ( + ", ".join( + "%r: %s" % (key, trunc(value)) + for key, value in params.items() + ) + ) + elif typ is self._TUPLE: + return "(%s%s)" % ( + ", ".join(trunc(value) for value in params), + "," if len(params) == 1 else "" + + ) + else: + return "[%s]" % ( + ", ".join(trunc(value) for value in params) + ) + + def adapt_criterion_to_null(crit, nulls): - """given criterion containing bind params, convert selected elements to IS NULL.""" + """given criterion containing bind params, convert selected elements + to IS NULL. + + """ def visit_binary(binary): - if isinstance(binary.left, expression._BindParamClause) and binary.left.key in nulls: + if isinstance(binary.left, BindParameter) \ + and binary.left._identifying_key in nulls: # reverse order if the NULL is on the left side binary.left = binary.right - binary.right = expression.null() + binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in nulls: - binary.right = expression.null() + elif isinstance(binary.right, BindParameter) \ + and binary.right._identifying_key in nulls: + binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) - - -def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): - """create a join condition between two tables or selectables. 
- - e.g.:: - - join_condition(tablea, tableb) - - would produce an expression along the lines of:: - - tablea.c.id==tableb.c.tablea_id - - The join is determined based on the foreign key relationships - between the two selectables. If there are multiple ways - to join, or no way to join, an error is raised. - - :param ignore_nonexistent_tables: This flag will cause the - function to silently skip over foreign key resolution errors - due to nonexistent tables - the assumption is that these - tables have not yet been defined within an initialization process - and are not significant to the operation. - - :param a_subset: An optional expression that is a sub-component - of ``a``. An attempt will be made to join to just this sub-component - first before looking at the full ``a`` construct, and if found - will be successful even if there are other ways to join to ``a``. - This allows the "right side" of a join to be passed thereby - providing a "natural join". - - """ - crit = [] - constraints = set() - - for left in (a_subset, a): - if left is None: - continue - for fk in b.foreign_keys: - try: - col = fk.get_referent(left) - except exc.NoReferencedTableError: - if ignore_nonexistent_tables: - continue - else: - raise - - if col is not None: - crit.append(col == fk.parent) - constraints.add(fk.constraint) - if left is not b: - for fk in left.foreign_keys: - try: - col = fk.get_referent(b) - except exc.NoReferencedTableError: - if ignore_nonexistent_tables: - continue - else: - raise - - if col is not None: - crit.append(col == fk.parent) - constraints.add(fk.constraint) - if crit: - break - - if len(crit) == 0: - if isinstance(b, expression._FromGrouping): - hint = " Perhaps you meant to convert the right side to a subquery using alias()?" - else: - hint = "" - raise exc.ArgumentError( - "Can't find any foreign key relationships " - "between '%s' and '%s'.%s" % (a.description, b.description, hint)) - elif len(constraints) > 1: - raise exc.ArgumentError( - "Can't determine join between '%s' and '%s'; " - "tables have more than one foreign key " - "constraint relationship between them. " - "Please specify the 'onclause' of this " - "join explicitly." % (a.description, b.description)) - elif len(crit) == 1: - return (crit[0]) - else: - return sql.and_(*crit) - - -class Annotated(object): - """clones a ClauseElement and applies an 'annotations' dictionary. - - Unlike regular clones, this clone also mimics __hash__() and - __cmp__() of the original element so that it takes its place - in hashed collections. - - A reference to the original element is maintained, for the important - reason of keeping its hash value current. When GC'ed, the - hash value may be reused, causing conflicts. 
- - """ - - def __new__(cls, *args): - if not args: - # clone constructor - return object.__new__(cls) - else: - element, values = args - # pull appropriate subclass from registry of annotated - # classes - try: - cls = annotated_classes[element.__class__] - except KeyError: - cls = annotated_classes[element.__class__] = type.__new__(type, - "Annotated%s" % element.__class__.__name__, - (Annotated, element.__class__), {}) - return object.__new__(cls) - - def __init__(self, element, values): - # force FromClause to generate their internal - # collections into __dict__ - if isinstance(element, expression.FromClause): - element.c - - self.__dict__ = element.__dict__.copy() - self.__element = element - self._annotations = values - - def _annotate(self, values): - _values = self._annotations.copy() - _values.update(values) - clone = self.__class__.__new__(self.__class__) - clone.__dict__ = self.__dict__.copy() - clone._annotations = _values - return clone - - def _deannotate(self): - return self.__element - - def _clone(self): - clone = self.__element._clone() - if clone is self.__element: - # detect immutable, don't change anything - return self - else: - # update the clone with any changes that have occured - # to this object's __dict__. - clone.__dict__.update(self.__dict__) - return Annotated(clone, self._annotations) - - def __hash__(self): - return hash(self.__element) - - def __cmp__(self, other): - return cmp(hash(self.__element), hash(other)) - -# hard-generate Annotated subclasses. this technique -# is used instead of on-the-fly types (i.e. type.__new__()) -# so that the resulting objects are pickleable. -annotated_classes = {} - -from sqlalchemy.sql import expression -for cls in expression.__dict__.values() + [schema.Column, schema.Table]: - if isinstance(cls, type) and issubclass(cls, expression.ClauseElement): - exec "class Annotated%s(Annotated, cls):\n" \ - " __visit_name__ = cls.__visit_name__\n"\ - " pass" % (cls.__name__, ) in locals() - exec "annotated_classes[cls] = Annotated%s" % (cls.__name__) - -def _deep_annotate(element, annotations, exclude=None): - """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. - - Elements within the exclude collection will be cloned but not annotated. - - """ - def clone(elem): - # check if element is present in the exclude list. - # take into account proxying relationships. 
- if exclude and \ - hasattr(elem, 'proxy_set') and \ - elem.proxy_set.intersection(exclude): - elem = elem._clone() - elif annotations != elem._annotations: - elem = elem._annotate(annotations.copy()) - elem._copy_internals(clone=clone) - return elem - - if element is not None: - element = clone(element) - return element - -def _deep_deannotate(element): - """Deep copy the given element, removing all annotations.""" - - def clone(elem): - elem = elem._deannotate() - elem._copy_internals(clone=clone) - return elem - - if element is not None: - element = clone(element) - return element + return visitors.cloned_traverse(crit, {}, {'binary': visit_binary}) def splice_joins(left, right, stop_on=None): if left is None: return right - + stack = [(right, None)] adapter = ClauseAdapter(left) ret = None while stack: (right, prevright) = stack.pop() - if isinstance(right, expression.Join) and right is not stop_on: + if isinstance(right, Join) and right is not stop_on: right = right._clone() right._reset_exported() right.onclause = adapter.traverse(right.onclause) @@ -353,27 +459,31 @@ def splice_joins(left, right, stop_on=None): ret = right return ret - + + def reduce_columns(columns, *clauses, **kw): - """given a list of columns, return a 'reduced' set based on natural equivalents. + r"""given a list of columns, return a 'reduced' set based on natural + equivalents. the set is reduced to the smallest list of columns which have no natural - equivalent present in the list. A "natural equivalent" means that two columns - will ultimately represent the same value because they are related by a foreign key. + equivalent present in the list. A "natural equivalent" means that two + columns will ultimately represent the same value because they are related + by a foreign key. \*clauses is an optional list of join clauses which will be traversed to further identify columns that are "equivalent". \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys - whose tables are not yet configured. - - This function is primarily used to determine the most minimal "primary key" - from a selectable, by reducing the set of primary key columns present - in the the selectable to just those that are not repeated. + whose tables are not yet configured, or columns that aren't yet present. + + This function is primarily used to determine the most minimal "primary + key" from a selectable, by reducing the set of primary key columns present + in the selectable to just those that are not repeated. 
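+
+    E.g., a sketch with hypothetical tables where ``b.c.a_id`` carries a
+    ForeignKey to ``a.c.id``::
+
+        reduce_columns([a.c.id, b.c.a_id, b.c.x])
+        # -> ColumnSet([a.c.id, b.c.x])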
""" ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) - + only_synonyms = kw.pop('only_synonyms', False) + columns = util.ordered_column_set(columns) omit = util.column_set() @@ -384,149 +494,98 @@ def reduce_columns(columns, *clauses, **kw): continue try: fk_col = fk.column - except exc.NoReferencedTableError: + except exc.NoReferencedColumnError: + # TODO: add specific coverage here + # to test/sql/test_selectable ReduceTest if ignore_nonexistent_tables: continue else: raise - if fk_col.shares_lineage(c): + except exc.NoReferencedTableError: + # TODO: add specific coverage here + # to test/sql/test_selectable ReduceTest + if ignore_nonexistent_tables: + continue + else: + raise + if fk_col.shares_lineage(c) and \ + (not only_synonyms or + c.name == col.name): omit.add(col) break if clauses: def visit_binary(binary): if binary.operator == operators.eq: - cols = util.column_set(chain(*[c.proxy_set for c in columns.difference(omit)])) + cols = util.column_set( + chain(*[c.proxy_set for c in columns.difference(omit)])) if binary.left in cols and binary.right in cols: - for c in columns: - if c.shares_lineage(binary.right): + for c in reversed(columns): + if c.shares_lineage(binary.right) and \ + (not only_synonyms or + c.name == binary.left.name): omit.add(c) break for clause in clauses: - visitors.traverse(clause, {}, {'binary':visit_binary}) + if clause is not None: + visitors.traverse(clause, {}, {'binary': visit_binary}) - return expression.ColumnSet(columns.difference(omit)) + return ColumnSet(columns.difference(omit)) -def criterion_as_pairs(expression, consider_as_foreign_keys=None, - consider_as_referenced_keys=None, any_operator=False): + +def criterion_as_pairs(expression, consider_as_foreign_keys=None, + consider_as_referenced_keys=None, any_operator=False): """traverse an expression and locate binary criterion pairs.""" - + if consider_as_foreign_keys and consider_as_referenced_keys: raise exc.ArgumentError("Can only specify one of " "'consider_as_foreign_keys' or " "'consider_as_referenced_keys'") - + + def col_is(a, b): + # return a is b + return a.compare(b) + def visit_binary(binary): if not any_operator and binary.operator is not operators.eq: return - if not isinstance(binary.left, sql.ColumnElement) or \ - not isinstance(binary.right, sql.ColumnElement): + if not isinstance(binary.left, ColumnElement) or \ + not isinstance(binary.right, ColumnElement): return if consider_as_foreign_keys: if binary.left in consider_as_foreign_keys and \ - (binary.right is binary.left or - binary.right not in consider_as_foreign_keys): + (col_is(binary.right, binary.left) or + binary.right not in consider_as_foreign_keys): pairs.append((binary.right, binary.left)) elif binary.right in consider_as_foreign_keys and \ - (binary.left is binary.right or - binary.left not in consider_as_foreign_keys): + (col_is(binary.left, binary.right) or + binary.left not in consider_as_foreign_keys): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: if binary.left in consider_as_referenced_keys and \ - (binary.right is binary.left or - binary.right not in consider_as_referenced_keys): + (col_is(binary.right, binary.left) or + binary.right not in consider_as_referenced_keys): pairs.append((binary.left, binary.right)) elif binary.right in consider_as_referenced_keys and \ - (binary.left is binary.right or - binary.left not in consider_as_referenced_keys): + (col_is(binary.left, binary.right) or + binary.left not in consider_as_referenced_keys): 
pairs.append((binary.right, binary.left)) else: - if isinstance(binary.left, schema.Column) and \ - isinstance(binary.right, schema.Column): + if isinstance(binary.left, Column) and \ + isinstance(binary.right, Column): if binary.left.references(binary.right): pairs.append((binary.right, binary.left)) elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) pairs = [] - visitors.traverse(expression, {}, {'binary':visit_binary}) + visitors.traverse(expression, {}, {'binary': visit_binary}) return pairs -def folded_equivalents(join, equivs=None): - """Return a list of uniquely named columns. - - The column list of the given Join will be narrowed - down to a list of all equivalently-named, - equated columns folded into one column, where 'equated' means they are - equated to each other in the ON clause of this join. - - This function is used by Join.select(fold_equivalents=True). - - Deprecated. This function is used for a certain kind of - "polymorphic_union" which is designed to achieve joined - table inheritance where the base table has no "discriminator" - column; [ticket:1131] will provide a better way to - achieve this. - - """ - if equivs is None: - equivs = set() - def visit_binary(binary): - if binary.operator == operators.eq and binary.left.name == binary.right.name: - equivs.add(binary.right) - equivs.add(binary.left) - visitors.traverse(join.onclause, {}, {'binary':visit_binary}) - collist = [] - if isinstance(join.left, expression.Join): - left = folded_equivalents(join.left, equivs) - else: - left = list(join.left.columns) - if isinstance(join.right, expression.Join): - right = folded_equivalents(join.right, equivs) - else: - right = list(join.right.columns) - used = set() - for c in left + right: - if c in equivs: - if c.name not in used: - collist.append(c) - used.add(c.name) - else: - collist.append(c) - return collist - -class AliasedRow(object): - """Wrap a RowProxy with a translation map. - - This object allows a set of keys to be translated - to those present in a RowProxy. - - """ - def __init__(self, row, map): - # AliasedRow objects don't nest, so un-nest - # if another AliasedRow was passed - if isinstance(row, AliasedRow): - self.row = row.row - else: - self.row = row - self.map = map - - def __contains__(self, key): - return self.map[key] in self.row - - def has_key(self, key): - return key in self - - def __getitem__(self, key): - return self.row[self.map[key]] - - def keys(self): - return self.row.keys() - class ClauseAdapter(visitors.ReplacingCloningVisitor): """Clones and modifies clauses based on column correspondence. 
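
A hedged, self-contained sketch of criterion_as_pairs() from above (table and column names invented): designating one side of an equality criterion as the foreign-key side yields (referenced, foreign) column pairs.

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.sql.util import criterion_as_pairs

    m = MetaData()
    a = Table('a', m, Column('id', Integer))
    b = Table('b', m, Column('a_id', Integer))

    pairs = criterion_as_pairs(a.c.id == b.c.a_id,
                               consider_as_foreign_keys={b.c.a_id})
    print(pairs)  # [(a.c.id, b.c.a_id)] as (referenced, foreign)
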
-    
+
     E.g.::

       table1 = Table('sometable', metadata,
@@ -550,102 +609,154 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
       s.c.col1 == table2.c.col1

    """
-    def __init__(self, selectable, equivalents=None, include=None, exclude=None):
-        self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]}
-        self.selectable = selectable
-        self.include = include
-        self.exclude = exclude
-        self.equivalents = util.column_dict(equivalents or {})
-
-    def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET):
-        newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded)
+    def __init__(self, selectable, equivalents=None,
+                 include_fn=None, exclude_fn=None,
+                 adapt_on_names=False, anonymize_labels=False):
+        self.__traverse_options__ = {
+            'stop_on': [selectable],
+            'anonymize_labels': anonymize_labels}
+        self.selectable = selectable
+        self.include_fn = include_fn
+        self.exclude_fn = exclude_fn
+        self.equivalents = util.column_dict(equivalents or {})
+        self.adapt_on_names = adapt_on_names
+
+    def _corresponding_column(self, col, require_embedded,
+                              _seen=util.EMPTY_SET):
+        newcol = self.selectable.corresponding_column(
+            col,
+            require_embedded=require_embedded)
         if newcol is None and col in self.equivalents and col not in _seen:
             for equiv in self.equivalents[col]:
-                newcol = self._corresponding_column(equiv, require_embedded=require_embedded, _seen=_seen.union([col]))
+                newcol = self._corresponding_column(
+                    equiv, require_embedded=require_embedded,
+                    _seen=_seen.union([col]))
                 if newcol is not None:
                     return newcol
+        if self.adapt_on_names and newcol is None:
+            newcol = self.selectable.c.get(col.name)
         return newcol

     def replace(self, col):
-        if isinstance(col, expression.FromClause):
-            if self.selectable.is_derived_from(col):
-                return self.selectable
+        if isinstance(col, FromClause) and \
+                self.selectable.is_derived_from(col):
+            return self.selectable
+        elif not isinstance(col, ColumnElement):
+            return None
+        elif self.include_fn and not self.include_fn(col):
+            return None
+        elif self.exclude_fn and self.exclude_fn(col):
+            return None
+        else:
+            return self._corresponding_column(col, True)

-        if not isinstance(col, expression.ColumnElement):
-            return None
-
-        if self.include and col not in self.include:
-            return None
-        elif self.exclude and col in self.exclude:
-            return None
-
-        return self._corresponding_column(col, True)

 class ColumnAdapter(ClauseAdapter):
     """Extends ClauseAdapter with extra utility functions.
-
-    Provides the ability to "wrap" this ClauseAdapter
-    around another, a columns dictionary which returns
-    adapted elements given an original, and an
-    adapted_row() factory.
-
+
+    Key aspects of ColumnAdapter include:
+
+    * Expressions that are adapted are stored in a persistent
+      .columns collection; so that an expression E adapted into
+      an expression E1 will return the same object E1 when adapted
+      a second time.  This is important in particular for things like
+      Label objects that are anonymized, so that the ColumnAdapter can
+      be used to present a consistent "adapted" view of things.
+
+    * Exclusion of items from the persistent collection based on
+      include/exclude rules, but also independent of hash identity.
+      This is because "annotated" items all have the same hash identity
+      as their parent.
+
+    * "wrapping" capability is added, so that the replacement of an expression
+      E can proceed through a series of adapters.  This differs from the
+      visitor's "chaining" feature in that the resulting object is passed
+      through all replacing functions unconditionally, rather than stopping
+      at the first one that returns non-None.
+
+    * An adapt_required option, used by eager loading to indicate that
+      we don't trust a result row column that is not translated.
+      This is to prevent a column from being interpreted as that
+      of the child row in a self-referential scenario, see
+      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency
+
     """
-    def __init__(self, selectable, equivalents=None,
-                 chain_to=None, include=None,
-                 exclude=None, adapt_required=False):
-        ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
+
+    def __init__(self, selectable, equivalents=None,
+                 chain_to=None, adapt_required=False,
+                 include_fn=None, exclude_fn=None,
+                 adapt_on_names=False,
+                 allow_label_resolve=True,
+                 anonymize_labels=False):
+        ClauseAdapter.__init__(self, selectable, equivalents,
+                               include_fn=include_fn, exclude_fn=exclude_fn,
+                               adapt_on_names=adapt_on_names,
+                               anonymize_labels=anonymize_labels)
+
        if chain_to:
            self.chain(chain_to)
        self.columns = util.populate_column_dict(self._locate_col)
+        if self.include_fn or self.exclude_fn:
+            self.columns = self._IncludeExcludeMapping(self, self.columns)
        self.adapt_required = adapt_required
+        self.allow_label_resolve = allow_label_resolve
+        self._wrap = None
+
+    class _IncludeExcludeMapping(object):
+        def __init__(self, parent, columns):
+            self.parent = parent
+            self.columns = columns
+
+        def __getitem__(self, key):
+            if (
+                self.parent.include_fn and not self.parent.include_fn(key)
+            ) or (
+                self.parent.exclude_fn and self.parent.exclude_fn(key)
+            ):
+                if self.parent._wrap:
+                    return self.parent._wrap.columns[key]
+                else:
+                    return key
+            return self.columns[key]

     def wrap(self, adapter):
         ac = self.__class__.__new__(self.__class__)
-        ac.__dict__ = self.__dict__.copy()
-        ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col)
-        ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause)
-        ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list)
+        ac.__dict__.update(self.__dict__)
+        ac._wrap = adapter
         ac.columns = util.populate_column_dict(ac._locate_col)
+        if ac.include_fn or ac.exclude_fn:
+            ac.columns = self._IncludeExcludeMapping(ac, ac.columns)
+
         return ac

-    adapt_clause = ClauseAdapter.traverse
+    def traverse(self, obj):
+        return self.columns[obj]
+
+    adapt_clause = traverse
     adapt_list = ClauseAdapter.copy_and_process

-    def _wrap(self, local, wrapped):
-        def locate(col):
-            col = local(col)
-            return wrapped(col)
-        return locate
-
     def _locate_col(self, col):
-        c = self._corresponding_column(col, True)
-        if c is None:
-            c = self.adapt_clause(col)
-
-            # anonymize labels in case they have a hardcoded name
-            if isinstance(c, expression._Label):
-                c = c.label(None)
-
-        # adapt_required indicates that if we got the same column
-        # back which we put in (i.e. it passed through),
-        # it's not correct.  this is used by eagerloading which
-        # knows that all columns and expressions need to be adapted
-        # to a result row, and a "passthrough" is definitely targeting
-        # the wrong column.
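
A hedged, self-contained sketch of the adapters described above (names invented): ClauseAdapter rewrites an expression written against a table so that it refers to an alias of that table; ColumnAdapter layers the persistent .columns mapping on top of the same mechanics.

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.sql.util import ClauseAdapter

    m = MetaData()
    t = Table('sometable', m, Column('col1', Integer))
    t_alias = t.alias('ta')

    adapted = ClauseAdapter(t_alias).traverse(t.c.col1 == 5)
    print(adapted)  # ta.col1 = :col1_1
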
+ + c = ClauseAdapter.traverse(self, col) + + if self._wrap: + c2 = self._wrap._locate_col(c) + if c2 is not None: + c = c2 + if self.adapt_required and c is col: return None - - return c - def adapted_row(self, row): - return AliasedRow(row, self.columns) - + c._allow_label_resolve = self.allow_label_resolve + + return c + def __getstate__(self): d = self.__dict__.copy() del d['columns'] return d - + def __setstate__(self, state): self.__dict__.update(state) self.columns = util.PopulateDict(self._locate_col) diff --git a/sqlalchemy/sql/visitors.py b/sqlalchemy/sql/visitors.py index 4a54375..7f09518 100644 --- a/sqlalchemy/sql/visitors.py +++ b/sqlalchemy/sql/visitors.py @@ -1,90 +1,137 @@ +# sql/visitors.py +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + """Visitor/traversal interface and library functions. SQLAlchemy schema and expression constructs rely on a Python-centric version of the classic "visitor" pattern as the primary way in which -they apply functionality. The most common use of this pattern -is statement compilation, where individual expression classes match -up to rendering methods that produce a string result. Beyond this, -the visitor system is also used to inspect expressions for various -information and patterns, as well as for usage in +they apply functionality. The most common use of this pattern +is statement compilation, where individual expression classes match +up to rendering methods that produce a string result. Beyond this, +the visitor system is also used to inspect expressions for various +information and patterns, as well as for usage in some kinds of expression transformation. Other kinds of transformation use a non-visitor traversal system. -For many examples of how the visit system is used, see the +For many examples of how the visit system is used, see the sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules. For an introduction to clause adaption, see -http://techspot.zzzeek.org/?p=19 . +http://techspot.zzzeek.org/2008/01/23/expression-transformations/ """ from collections import deque -import re -from sqlalchemy import util +from .. import util import operator +from .. import exc + +__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', + 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate', + 'iterate_depthfirst', 'traverse_using', 'traverse', + 'traverse_depthfirst', + 'cloned_traverse', 'replacement_traverse'] + -__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', - 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate', - 'iterate_depthfirst', 'traverse_using', 'traverse', - 'cloned_traverse', 'replacement_traverse'] - class VisitableType(type): - """Metaclass which checks for a `__visit_name__` attribute and - applies `_compiler_dispatch` method to classes. - + """Metaclass which assigns a `_compiler_dispatch` method to classes + having a `__visit_name__` attribute. + + The _compiler_dispatch attribute becomes an instance method which + looks approximately like the following:: + + def _compiler_dispatch (self, visitor, **kw): + '''Look for an attribute named "visit_" + self.__visit_name__ + on the visitor, and call it with the same kw params.''' + visit_attr = 'visit_%s' % self.__visit_name__ + return getattr(visitor, visit_attr)(self, **kw) + + Classes having no __visit_name__ attribute will remain unaffected. 
""" - + def __init__(cls, clsname, bases, clsdict): - if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'): - super(VisitableType, cls).__init__(clsname, bases, clsdict) - return - - # set up an optimized visit dispatch function - # for use by the compiler - visit_name = cls.__visit_name__ - if isinstance(visit_name, str): - getter = operator.attrgetter("visit_%s" % visit_name) - def _compiler_dispatch(self, visitor, **kw): - return getter(visitor)(self, **kw) - else: - def _compiler_dispatch(self, visitor, **kw): - return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) - - cls._compiler_dispatch = _compiler_dispatch - + if clsname != 'Visitable' and \ + hasattr(cls, '__visit_name__'): + _generate_dispatch(cls) + super(VisitableType, cls).__init__(clsname, bases, clsdict) -class Visitable(object): + +def _generate_dispatch(cls): + """Return an optimized visit dispatch function for the cls + for use by the compiler. + """ + if '__visit_name__' in cls.__dict__: + visit_name = cls.__visit_name__ + if isinstance(visit_name, str): + # There is an optimization opportunity here because the + # the string name of the class's __visit_name__ is known at + # this early stage (import time) so it can be pre-constructed. + getter = operator.attrgetter("visit_%s" % visit_name) + + def _compiler_dispatch(self, visitor, **kw): + try: + meth = getter(visitor) + except AttributeError: + raise exc.UnsupportedCompilationError(visitor, cls) + else: + return meth(self, **kw) + else: + # The optimization opportunity is lost for this case because the + # __visit_name__ is not yet a string. As a result, the visit + # string has to be recalculated with each compilation. + def _compiler_dispatch(self, visitor, **kw): + visit_attr = 'visit_%s' % self.__visit_name__ + try: + meth = getattr(visitor, visit_attr) + except AttributeError: + raise exc.UnsupportedCompilationError(visitor, cls) + else: + return meth(self, **kw) + + _compiler_dispatch.__doc__ = \ + """Look for an attribute named "visit_" + self.__visit_name__ + on the visitor, and call it with the same kw params. + """ + cls._compiler_dispatch = _compiler_dispatch + + +class Visitable(util.with_metaclass(VisitableType, object)): """Base class for visitable objects, applies the ``VisitableType`` metaclass. - + """ - __metaclass__ = VisitableType class ClauseVisitor(object): - """Base class for visitor objects which can traverse using + """Base class for visitor objects which can traverse using the traverse() function. - + """ - + __traverse_options__ = {} - - def traverse_single(self, obj): + + def traverse_single(self, obj, **kw): for v in self._visitor_iterator: meth = getattr(v, "visit_%s" % obj.__visit_name__, None) if meth: - return meth(obj) - - def iterate(self, obj): - """traverse the given expression structure, returning an iterator of all elements.""" + return meth(obj, **kw) + def iterate(self, obj): + """traverse the given expression structure, returning an iterator + of all elements. 
+ + """ return iterate(obj, self.__traverse_options__) - + def traverse(self, obj): """traverse and visit the given expression structure.""" return traverse(obj, self.__traverse_options__, self._visitor_dict) - + @util.memoized_property def _visitor_dict(self): visitors = {} @@ -93,11 +140,11 @@ class ClauseVisitor(object): if name.startswith('visit_'): visitors[name[6:]] = getattr(self, name) return visitors - + @property def _visitor_iterator(self): """iterate through this visitor and each 'chained' visitor.""" - + v = self while v: yield v @@ -105,41 +152,46 @@ class ClauseVisitor(object): def chain(self, visitor): """'chain' an additional ClauseVisitor onto this ClauseVisitor. - + the chained visitor will receive all visit events after this one. - + """ tail = list(self._visitor_iterator)[-1] tail._next = visitor return self + class CloningVisitor(ClauseVisitor): - """Base class for visitor objects which can traverse using + """Base class for visitor objects which can traverse using the cloned_traverse() function. - + """ def copy_and_process(self, list_): - """Apply cloned traversal to the given list of elements, and return the new list.""" + """Apply cloned traversal to the given list of elements, and return + the new list. + """ return [self.traverse(x) for x in list_] def traverse(self, obj): """traverse and visit the given expression structure.""" - return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict) + return cloned_traverse( + obj, self.__traverse_options__, self._visitor_dict) + class ReplacingCloningVisitor(CloningVisitor): - """Base class for visitor objects which can traverse using + """Base class for visitor objects which can traverse using the replacement_traverse() function. - + """ def replace(self, elem): """receive pre-copied elements during a cloning traversal. - - If the method returns a new element, the element is used - instead of creating a simple copy of the element. Traversal + + If the method returns a new element, the element is used + instead of creating a simple copy of the element. Traversal will halt on the newly returned element if it is re-encountered. """ return None @@ -154,25 +206,39 @@ class ReplacingCloningVisitor(CloningVisitor): return e return replacement_traverse(obj, self.__traverse_options__, replace) + def iterate(obj, opts): """traverse the given expression structure, returning an iterator. - + traversal is configured to be breadth-first. - + """ + # fasttrack for atomic elements like columns + children = obj.get_children(**opts) + if not children: + return [obj] + + traversal = deque() stack = deque([obj]) while stack: t = stack.popleft() - yield t + traversal.append(t) for c in t.get_children(**opts): stack.append(c) + return iter(traversal) + def iterate_depthfirst(obj, opts): """traverse the given expression structure, returning an iterator. - + traversal is configured to be depth-first. - + """ + # fasttrack for atomic elements like columns + children = obj.get_children(**opts) + if not children: + return [obj] + stack = deque([obj]) traversal = deque() while stack: @@ -182,75 +248,81 @@ def iterate_depthfirst(obj, opts): stack.append(c) return iter(traversal) -def traverse_using(iterator, obj, visitors): - """visit the given expression structure using the given iterator of objects.""" +def traverse_using(iterator, obj, visitors): + """visit the given expression structure using the given iterator of + objects. 
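
Hedged, self-contained sketches of the traversal functions above: traverse() visits elements in place, while cloned_traverse() copies the structure first, so visitors may mutate the clones without touching the original.

    from sqlalchemy import column
    from sqlalchemy.sql import visitors

    expr = column('a') == 5

    # traverse: dispatch on __visit_name__; collect every ColumnClause
    seen = []
    visitors.traverse(expr, {}, {'column': seen.append})
    print([c.name for c in seen])  # ['a']

    # cloned_traverse: only the copy's bind parameter is changed
    def visit_bindparam(bind):
        bind.value = 10

    cloned = visitors.cloned_traverse(expr, {},
                                      {'bindparam': visit_bindparam})
    print(expr.right.value, cloned.right.value)  # 5 10
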
+ + """ for target in iterator: meth = visitors.get(target.__visit_name__, None) if meth: meth(target) return obj - -def traverse(obj, opts, visitors): - """traverse and visit the given expression structure using the default iterator.""" + +def traverse(obj, opts, visitors): + """traverse and visit the given expression structure using the default + iterator. + + """ return traverse_using(iterate(obj, opts), obj, visitors) -def traverse_depthfirst(obj, opts, visitors): - """traverse and visit the given expression structure using the depth-first iterator.""" +def traverse_depthfirst(obj, opts, visitors): + """traverse and visit the given expression structure using the + depth-first iterator. + + """ return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) + def cloned_traverse(obj, opts, visitors): - """clone the given expression structure, allowing modifications by visitors.""" - - cloned = util.column_dict() + """clone the given expression structure, allowing + modifications by visitors.""" - def clone(element): - if element not in cloned: - cloned[element] = element._clone() - return cloned[element] + cloned = {} + stop_on = set(opts.get('stop_on', [])) - obj = clone(obj) - stack = [obj] + def clone(elem): + if elem in stop_on: + return elem + else: + if id(elem) not in cloned: + cloned[id(elem)] = newelem = elem._clone() + newelem._copy_internals(clone=clone) + meth = visitors.get(newelem.__visit_name__, None) + if meth: + meth(newelem) + return cloned[id(elem)] - while stack: - t = stack.pop() - if t in cloned: - continue - t._copy_internals(clone=clone) - - meth = visitors.get(t.__visit_name__, None) - if meth: - meth(t) - - for c in t.get_children(**opts): - stack.append(c) + if obj is not None: + obj = clone(obj) return obj + def replacement_traverse(obj, opts, replace): - """clone the given expression structure, allowing element replacement by a given replacement function.""" - - cloned = util.column_dict() - stop_on = util.column_set(opts.get('stop_on', [])) + """clone the given expression structure, allowing element + replacement by a given replacement function.""" - def clone(element): - newelem = replace(element) - if newelem is not None: - stop_on.add(newelem) - return newelem + cloned = {} + stop_on = set([id(x) for x in opts.get('stop_on', [])]) - if element not in cloned: - cloned[element] = element._clone() - return cloned[element] + def clone(elem, **kw): + if id(elem) in stop_on or \ + 'no_replacement_traverse' in elem._annotations: + return elem + else: + newelem = replace(elem) + if newelem is not None: + stop_on.add(id(newelem)) + return newelem + else: + if elem not in cloned: + cloned[elem] = newelem = elem._clone() + newelem._copy_internals(clone=clone, **kw) + return cloned[elem] - obj = clone(obj) - stack = [obj] - while stack: - t = stack.pop() - if t in stop_on: - continue - t._copy_internals(clone=clone) - for c in t.get_children(**opts): - stack.append(c) + if obj is not None: + obj = clone(obj, **opts) return obj diff --git a/sqlalchemy/types.py b/sqlalchemy/types.py index 16cd57f..ea07b91 100644 --- a/sqlalchemy/types.py +++ b/sqlalchemy/types.py @@ -1,1742 +1,81 @@ # types.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com +# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors +# # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""defines genericized SQL types, each represented by a subclass of 
-:class:`~sqlalchemy.types.AbstractType`. Dialects define further subclasses of these -types. - -For more information see the SQLAlchemy documentation on types. +"""Compatibility namespace for sqlalchemy.sql.types. """ -__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType', - 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text', - 'FLOAT', 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', - 'BLOB', 'BOOLEAN', 'SMALLINT', 'INTEGER', 'DATE', 'TIME', - 'String', 'Integer', 'SmallInteger', 'BigInteger', 'Numeric', - 'Float', 'DateTime', 'Date', 'Time', 'LargeBinary', 'Binary', 'Boolean', - 'Unicode', 'MutableType', 'Concatenable', 'UnicodeText', - 'PickleType', 'Interval', 'type_map', 'Enum' ] - -import inspect -import datetime as dt -from decimal import Decimal as _python_Decimal -import codecs - -from sqlalchemy import exc, schema -from sqlalchemy.sql import expression, operators -import sys -schema.types = expression.sqltypes =sys.modules['sqlalchemy.types'] -from sqlalchemy.util import pickle -from sqlalchemy.sql.visitors import Visitable -from sqlalchemy import util -from sqlalchemy import processors -import collections - -NoneType = type(None) -if util.jython: - import array - -class AbstractType(Visitable): - - def __init__(self, *args, **kwargs): - pass - - def compile(self, dialect): - return dialect.type_compiler.process(self) - - def copy_value(self, value): - return value - - def bind_processor(self, dialect): - """Defines a bind parameter processing function. - - :param dialect: Dialect instance in use. - - """ - - return None - - def result_processor(self, dialect, coltype): - """Defines a result-column processing function. - - :param dialect: Dialect instance in use. - - :param coltype: DBAPI coltype argument received in cursor.description. - - """ - - return None - - def compare_values(self, x, y): - """Compare two values for equality.""" - - return x == y - - def is_mutable(self): - """Return True if the target Python type is 'mutable'. - - This allows systems like the ORM to know if a column value can - be considered 'not changed' by comparing the identity of - objects alone. - - Use the :class:`MutableType` mixin or override this method to - return True in custom types that hold mutable values such as - ``dict``, ``list`` and custom objects. - - """ - return False - - def get_dbapi_type(self, dbapi): - """Return the corresponding type object from the underlying DB-API, if - any. - - This can be useful for calling ``setinputsizes()``, for example. - - """ - return None - - def _adapt_expression(self, op, othertype): - """evaluate the return type of , - and apply any adaptations to the given operator. 
- - """ - return op, self - - @util.memoized_property - def _type_affinity(self): - """Return a rudimental 'affinity' value expressing the general class of type.""" - - typ = None - for t in self.__class__.__mro__: - if t is TypeEngine or t is UserDefinedType: - return typ - elif issubclass(t, TypeEngine): - typ = t - else: - return self.__class__ - - def _coerce_compared_value(self, op, value): - _coerced_type = type_map.get(type(value), NULLTYPE) - if _coerced_type is NULLTYPE or _coerced_type._type_affinity is self._type_affinity: - return self - else: - return _coerced_type - - def _compare_type_affinity(self, other): - return self._type_affinity is other._type_affinity - - def __repr__(self): - return "%s(%s)" % ( - self.__class__.__name__, - ", ".join("%s=%r" % (k, getattr(self, k, None)) - for k in inspect.getargspec(self.__init__)[0][1:])) - -class TypeEngine(AbstractType): - """Base for built-in types.""" - - @util.memoized_property - def _impl_dict(self): - return {} - - def dialect_impl(self, dialect, **kwargs): - key = (dialect.__class__, dialect.server_version_info) - - try: - return self._impl_dict[key] - except KeyError: - return self._impl_dict.setdefault(key, dialect.type_descriptor(self)) - - def __getstate__(self): - d = self.__dict__.copy() - d.pop('_impl_dict', None) - return d - - def bind_processor(self, dialect): - """Return a conversion function for processing bind values. - - Returns a callable which will receive a bind parameter value - as the sole positional argument and will return a value to - send to the DB-API. - - If processing is not necessary, the method should return ``None``. - - """ - return None - - def result_processor(self, dialect, coltype): - """Return a conversion function for processing result row values. - - Returns a callable which will receive a result row column - value as the sole positional argument and will return a value - to return to the user. - - If processing is not necessary, the method should return ``None``. - - """ - return None - - def adapt(self, cls): - return cls() - -class UserDefinedType(TypeEngine): - """Base for user defined types. - - This should be the base of new types. Note that - for most cases, :class:`TypeDecorator` is probably - more appropriate:: - - import sqlalchemy.types as types - - class MyType(types.UserDefinedType): - def __init__(self, precision = 8): - self.precision = precision - - def get_col_spec(self): - return "MYTYPE(%s)" % self.precision - - def bind_processor(self, dialect): - def process(value): - return value - return process - - def result_processor(self, dialect, coltype): - def process(value): - return value - return process - - Once the type is made, it's immediately usable:: - - table = Table('foo', meta, - Column('id', Integer, primary_key=True), - Column('data', MyType(16)) - ) - - """ - __visit_name__ = "user_defined" - - def _adapt_expression(self, op, othertype): - """evaluate the return type of , - and apply any adaptations to the given operator. - - """ - return self.adapt_operator(op), self - - def adapt_operator(self, op): - """A hook which allows the given operator to be adapted - to something new. - - See also UserDefinedType._adapt_expression(), an as-yet- - semi-public method with greater capability in this regard. - - """ - return op - -class TypeDecorator(AbstractType): - """Allows the creation of types which add additional functionality - to an existing type. 
- - This method is preferred to direct subclassing of SQLAlchemy's - built-in types as it ensures that all required functionality of - the underlying type is kept in place. - - Typical usage:: - - import sqlalchemy.types as types - - class MyType(types.TypeDecorator): - '''Prefixes Unicode values with "PREFIX:" on the way in and - strips it off on the way out. - ''' - - impl = types.Unicode - - def process_bind_param(self, value, dialect): - return "PREFIX:" + value - - def process_result_value(self, value, dialect): - return value[7:] - - def copy(self): - return MyType(self.impl.length) - - The class-level "impl" variable is required, and can reference any - TypeEngine class. Alternatively, the load_dialect_impl() method - can be used to provide different type classes based on the dialect - given; in this case, the "impl" variable can reference - ``TypeEngine`` as a placeholder. - - Types that receive a Python type that isn't similar to the - ultimate type used may want to define the :meth:`TypeDecorator.coerce_compared_value` - method. This is used to give the expression system a hint - when coercing Python objects into bind parameters within expressions. - Consider this expression:: - - mytable.c.somecol + datetime.date(2009, 5, 15) - - Above, if "somecol" is an ``Integer`` variant, it makes sense that - we're doing date arithmetic, where above is usually interpreted - by databases as adding a number of days to the given date. - The expression system does the right thing by not attempting to - coerce the "date()" value into an integer-oriented bind parameter. - - However, in the case of ``TypeDecorator``, we are usually changing - an incoming Python type to something new - ``TypeDecorator`` by - default will "coerce" the non-typed side to be the same type as itself. - Such as below, we define an "epoch" type that stores a date value as an integer:: - - class MyEpochType(types.TypeDecorator): - impl = types.Integer - - epoch = datetime.date(1970, 1, 1) - - def process_bind_param(self, value, dialect): - return (value - self.epoch).days - - def process_result_value(self, value, dialect): - return self.epoch + timedelta(days=value) - - Our expression of ``somecol + date`` with the above type will coerce the - "date" on the right side to also be treated as ``MyEpochType``. - - This behavior can be overridden via the :meth:`~TypeDecorator.coerce_compared_value` - method, which returns a type that should be used for the value of the expression. - Below we set it such that an integer value will be treated as an ``Integer``, - and any other value is assumed to be a date and will be treated as a ``MyEpochType``:: - - def coerce_compared_value(self, op, value): - if isinstance(value, int): - return Integer() - else: - return self - - """ - - __visit_name__ = "type_decorator" - - def __init__(self, *args, **kwargs): - if not hasattr(self.__class__, 'impl'): - raise AssertionError("TypeDecorator implementations require a class-level " - "variable 'impl' which refers to the class of type being decorated") - self.impl = self.__class__.impl(*args, **kwargs) - - def adapt(self, cls): - return cls() - - def dialect_impl(self, dialect): - key = (dialect.__class__, dialect.server_version_info) - try: - return self._impl_dict[key] - except KeyError: - pass - - # adapt the TypeDecorator first, in - # the case that the dialect maps the TD - # to one of its native types (i.e. 
PGInterval) - adapted = dialect.type_descriptor(self) - if adapted is not self: - self._impl_dict[key] = adapted - return adapted - - # otherwise adapt the impl type, link - # to a copy of this TypeDecorator and return - # that. - typedesc = self.load_dialect_impl(dialect) - tt = self.copy() - if not isinstance(tt, self.__class__): - raise AssertionError("Type object %s does not properly implement the copy() " - "method, it must return an object of type %s" % (self, self.__class__)) - tt.impl = typedesc - self._impl_dict[key] = tt - return tt - - @util.memoized_property - def _type_affinity(self): - return self.impl._type_affinity - - def type_engine(self, dialect): - impl = self.dialect_impl(dialect) - if not isinstance(impl, TypeDecorator): - return impl - else: - return impl.impl - - def load_dialect_impl(self, dialect): - """Loads the dialect-specific implementation of this type. - - by default calls dialect.type_descriptor(self.impl), but - can be overridden to provide different behavior. - - """ - if isinstance(self.impl, TypeDecorator): - return self.impl.dialect_impl(dialect) - else: - return dialect.type_descriptor(self.impl) - - def __getattr__(self, key): - """Proxy all other undefined accessors to the underlying implementation.""" - - return getattr(self.impl, key) - - def process_bind_param(self, value, dialect): - raise NotImplementedError() - - def process_result_value(self, value, dialect): - raise NotImplementedError() - - def bind_processor(self, dialect): - if self.__class__.process_bind_param.func_code is not TypeDecorator.process_bind_param.func_code: - process_param = self.process_bind_param - impl_processor = self.impl.bind_processor(dialect) - if impl_processor: - def process(value): - return impl_processor(process_param(value, dialect)) - else: - def process(value): - return process_param(value, dialect) - return process - else: - return self.impl.bind_processor(dialect) - - def result_processor(self, dialect, coltype): - if self.__class__.process_result_value.func_code is not TypeDecorator.process_result_value.func_code: - process_value = self.process_result_value - impl_processor = self.impl.result_processor(dialect, coltype) - if impl_processor: - def process(value): - return process_value(impl_processor(value), dialect) - else: - def process(value): - return process_value(value, dialect) - return process - else: - return self.impl.result_processor(dialect, coltype) - - def coerce_compared_value(self, op, value): - """Suggest a type for a 'coerced' Python value in an expression. - - By default, returns self. This method is called by - the expression system when an object using this type is - on the left or right side of an expression against a plain Python - object which does not yet have a SQLAlchemy type assigned:: - - expr = table.c.somecolumn + 35 - - Where above, if ``somecolumn`` uses this type, this method will - be called with the value ``operator.add`` - and ``35``. The return value is whatever SQLAlchemy type should - be used for ``35`` for this particular operation. 
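
A hedged, runnable consolidation of the epoch example sketched in the docstring above (modern import paths assumed; the historical 0.6-era API may differ in details):

    import datetime
    from sqlalchemy import types, Integer

    class MyEpochType(types.TypeDecorator):
        impl = types.Integer

        epoch = datetime.date(1970, 1, 1)

        def process_bind_param(self, value, dialect):
            return (value - self.epoch).days

        def process_result_value(self, value, dialect):
            return self.epoch + datetime.timedelta(days=value)

        def coerce_compared_value(self, op, value):
            # plain integers compare as Integer; anything else as this type
            if isinstance(value, int):
                return Integer()
            else:
                return self
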
- - """ - return self - - def _coerce_compared_value(self, op, value): - return self.coerce_compared_value(op, value) - - def copy(self): - instance = self.__class__.__new__(self.__class__) - instance.__dict__.update(self.__dict__) - instance._impl_dict = {} - return instance - - def get_dbapi_type(self, dbapi): - return self.impl.get_dbapi_type(dbapi) - - def copy_value(self, value): - return self.impl.copy_value(value) - - def compare_values(self, x, y): - return self.impl.compare_values(x, y) - - def is_mutable(self): - return self.impl.is_mutable() - - def _adapt_expression(self, op, othertype): - return self.impl._adapt_expression(op, othertype) - - - -class MutableType(object): - """A mixin that marks a Type as holding a mutable object. - - :meth:`copy_value` and :meth:`compare_values` should be customized - as needed to match the needs of the object. - - """ - - def is_mutable(self): - """Return True, mutable.""" - return True - - def copy_value(self, value): - """Unimplemented.""" - raise NotImplementedError() - - def compare_values(self, x, y): - """Compare *x* == *y*.""" - return x == y - -def to_instance(typeobj): - if typeobj is None: - return NULLTYPE - - if util.callable(typeobj): - return typeobj() - else: - return typeobj - -def adapt_type(typeobj, colspecs): - if isinstance(typeobj, type): - typeobj = typeobj() - for t in typeobj.__class__.__mro__[0:-1]: - try: - impltype = colspecs[t] - break - except KeyError: - pass - else: - # couldnt adapt - so just return the type itself - # (it may be a user-defined type) - return typeobj - # if we adapted the given generic type to a database-specific type, - # but it turns out the originally given "generic" type - # is actually a subclass of our resulting type, then we were already - # given a more specific type than that required; so use that. - if (issubclass(typeobj.__class__, impltype)): - return typeobj - return typeobj.adapt(impltype) - -class NullType(TypeEngine): - """An unknown type. - - NullTypes will stand in if :class:`~sqlalchemy.Table` reflection - encounters a column data type unknown to SQLAlchemy. The - resulting columns are nearly fully usable: the DB-API adapter will - handle all translation to and from the database data type. - - NullType does not have sufficient information to particpate in a - ``CREATE TABLE`` statement and will raise an exception if - encountered during a :meth:`~sqlalchemy.Table.create` operation. - - """ - __visit_name__ = 'null' - - def _adapt_expression(self, op, othertype): - if othertype is NullType or not operators.is_commutative(op): - return op, self - else: - return othertype._adapt_expression(op, self) - -NullTypeEngine = NullType - -class Concatenable(object): - """A mixin that marks a type as supporting 'concatenation', typically strings.""" - - def _adapt_expression(self, op, othertype): - if op is operators.add and issubclass(othertype._type_affinity, (Concatenable, NullType)): - return operators.concat_op, self - else: - return op, self - -class _DateAffinity(object): - """Mixin date/time specific expression adaptations. - - Rules are implemented within Date,Time,Interval,DateTime, Numeric, Integer. - Based on http://www.postgresql.org/docs/current/static/functions-datetime.html. 
- - """ - - @property - def _expression_adaptations(self): - raise NotImplementedError() - - _blank_dict = util.frozendict() - def _adapt_expression(self, op, othertype): - othertype = othertype._type_affinity - return op, \ - self._expression_adaptations.get(op, self._blank_dict).\ - get(othertype, NULLTYPE) - -class String(Concatenable, TypeEngine): - """The base for all string and character types. - - In SQL, corresponds to VARCHAR. Can also take Python unicode objects - and encode to the database's encoding in bind params (and the reverse for - result sets.) - - The `length` field is usually required when the `String` type is - used within a CREATE TABLE statement, as VARCHAR requires a length - on most databases. - - """ - - __visit_name__ = 'string' - - def __init__(self, length=None, convert_unicode=False, - assert_unicode=None, unicode_error=None, - _warn_on_bytestring=False - ): - """ - Create a string-holding type. - - :param length: optional, a length for the column for use in - DDL statements. May be safely omitted if no ``CREATE - TABLE`` will be issued. Certain databases may require a - *length* for use in DDL, and will raise an exception when - the ``CREATE TABLE`` DDL is issued. Whether the value is - interpreted as bytes or characters is database specific. - - :param convert_unicode: defaults to False. If True, the - type will do what is necessary in order to accept - Python Unicode objects as bind parameters, and to return - Python Unicode objects in result rows. This may - require SQLAlchemy to explicitly coerce incoming Python - unicodes into an encoding, and from an encoding - back to Unicode, or it may not require any interaction - from SQLAlchemy at all, depending on the DBAPI in use. - - When SQLAlchemy performs the encoding/decoding, - the encoding used is configured via - :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which - defaults to `utf-8`. - - The "convert_unicode" behavior can also be turned on - for all String types by setting - :attr:`sqlalchemy.engine.base.Dialect.convert_unicode` - on create_engine(). - - To instruct SQLAlchemy to perform Unicode encoding/decoding - even on a platform that already handles Unicode natively, - set convert_unicode='force'. This will incur significant - performance overhead when fetching unicode result columns. - - :param assert_unicode: Deprecated. A warning is raised in all cases when a non-Unicode - object is passed when SQLAlchemy would coerce into an encoding - (note: but **not** when the DBAPI handles unicode objects natively). - To suppress or raise this warning to an - error, use the Python warnings filter documented at: - http://docs.python.org/library/warnings.html - - :param unicode_error: Optional, a method to use to handle Unicode - conversion errors. Behaves like the 'errors' keyword argument to - the standard library's string.decode() functions. This flag - requires that `convert_unicode` is set to `"force"` - otherwise, - SQLAlchemy is not guaranteed to handle the task of unicode - conversion. Note that this flag adds significant performance - overhead to row-fetching operations for backends that already - return unicode objects natively (which most DBAPIs do). This - flag should only be used as an absolute last resort for reading - strings from a column with varied or corrupted encodings, - which only applies to databases that accept invalid encodings - in the first place (i.e. MySQL. *not* PG, Sqlite, etc.) 
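
A hedged sketch of declaring columns with the parameters documented above (table and column names invented); most backends require a length for VARCHAR in CREATE TABLE:

    from sqlalchemy import MetaData, Table, Column, String, Unicode

    m = MetaData()
    notes = Table(
        'notes', m,
        Column('slug', String(50)),      # VARCHAR(50)
        Column('body', Unicode(200)))    # unicode-capable VARCHAR
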
- - """ - if unicode_error is not None and convert_unicode != 'force': - raise exc.ArgumentError("convert_unicode must be 'force' " - "when unicode_error is set.") - - if assert_unicode: - util.warn_deprecated("assert_unicode is deprecated. " - "SQLAlchemy emits a warning in all cases where it " - "would otherwise like to encode a Python unicode object " - "into a specific encoding but a plain bytestring is received. " - "This does *not* apply to DBAPIs that coerce Unicode natively." - ) - self.length = length - self.convert_unicode = convert_unicode - self.unicode_error = unicode_error - self._warn_on_bytestring = _warn_on_bytestring - - def adapt(self, impltype): - return impltype( - length=self.length, - convert_unicode=self.convert_unicode, - unicode_error=self.unicode_error, - _warn_on_bytestring=True, - ) - - def bind_processor(self, dialect): - if self.convert_unicode or dialect.convert_unicode: - if dialect.supports_unicode_binds and self.convert_unicode != 'force': - if self._warn_on_bytestring: - def process(value): - # Py3K - #if isinstance(value, bytes): - # Py2K - if isinstance(value, str): - # end Py2K - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) - return value - return process - else: - return None - else: - encoder = codecs.getencoder(dialect.encoding) - def process(value): - if isinstance(value, unicode): - return encoder(value, self.unicode_error)[0] - elif value is not None: - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) - return value - return process - else: - return None - - def result_processor(self, dialect, coltype): - wants_unicode = self.convert_unicode or dialect.convert_unicode - needs_convert = wants_unicode and \ - (dialect.returns_unicode_strings is not True or - self.convert_unicode == 'force') - - if needs_convert: - to_unicode = processors.to_unicode_processor_factory( - dialect.encoding, self.unicode_error) - - if dialect.returns_unicode_strings: - # we wouldn't be here unless convert_unicode='force' - # was specified, or the driver has erratic unicode-returning - # habits. since we will be getting back unicode - # in most cases, we check for it (decode will fail). - def process(value): - if isinstance(value, unicode): - return value - else: - return to_unicode(value) - return process - else: - # here, we assume that the object is not unicode, - # avoiding expensive isinstance() check. - return to_unicode - else: - return None - - def get_dbapi_type(self, dbapi): - return dbapi.STRING - -class Text(String): - """A variably sized string type. - - In SQL, usually corresponds to CLOB or TEXT. Can also take Python - unicode objects and encode to the database's encoding in bind - params (and the reverse for result sets.) - - """ - __visit_name__ = 'text' - -class Unicode(String): - """A variable length Unicode string. - - The ``Unicode`` type is a :class:`String` which converts Python - ``unicode`` objects (i.e., strings that are defined as - ``u'somevalue'``) into encoded bytestrings when passing the value - to the database driver, and similarly decodes values from the - database back into Python ``unicode`` objects. - - It's roughly equivalent to using a ``String`` object with - ``convert_unicode=True``, however - the type has other significances in that it implies the usage - of a unicode-capable type being used on the backend, such as NVARCHAR. 
- This may affect what type is emitted when issuing CREATE TABLE - and also may effect some DBAPI-specific details, such as type - information passed along to ``setinputsizes()``. - - When using the ``Unicode`` type, it is only appropriate to pass - Python ``unicode`` objects, and not plain ``str``. If a - bytestring (``str``) is passed, a runtime warning is issued. If - you notice your application raising these warnings but you're not - sure where, the Python ``warnings`` filter can be used to turn - these warnings into exceptions which will illustrate a stack - trace:: - - import warnings - warnings.simplefilter('error') - - Bytestrings sent to and received from the database are encoded - using the dialect's - :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which defaults - to `utf-8`. - - """ - - __visit_name__ = 'unicode' - - def __init__(self, length=None, **kwargs): - """ - Create a Unicode-converting String type. - - :param length: optional, a length for the column for use in - DDL statements. May be safely omitted if no ``CREATE - TABLE`` will be issued. Certain databases may require a - *length* for use in DDL, and will raise an exception when - the ``CREATE TABLE`` DDL is issued. Whether the value is - interpreted as bytes or characters is database specific. - - :param \**kwargs: passed through to the underlying ``String`` - type. - - """ - kwargs.setdefault('convert_unicode', True) - kwargs.setdefault('_warn_on_bytestring', True) - super(Unicode, self).__init__(length=length, **kwargs) - -class UnicodeText(Text): - """An unbounded-length Unicode string. - - See :class:`Unicode` for details on the unicode - behavior of this object. - - Like ``Unicode``, usage the ``UnicodeText`` type implies a - unicode-capable type being used on the backend, such as NCLOB. - - """ - - __visit_name__ = 'unicode_text' - - def __init__(self, length=None, **kwargs): - """ - Create a Unicode-converting Text type. - - :param length: optional, a length for the column for use in - DDL statements. May be safely omitted if no ``CREATE - TABLE`` will be issued. Certain databases may require a - *length* for use in DDL, and will raise an exception when - the ``CREATE TABLE`` DDL is issued. Whether the value is - interpreted as bytes or characters is database specific. - - """ - kwargs.setdefault('convert_unicode', True) - kwargs.setdefault('_warn_on_bytestring', True) - super(UnicodeText, self).__init__(length=length, **kwargs) - - -class Integer(_DateAffinity, TypeEngine): - """A type for ``int`` integers.""" - - __visit_name__ = 'integer' - - def get_dbapi_type(self, dbapi): - return dbapi.NUMBER - - @util.memoized_property - def _expression_adaptations(self): - # TODO: need a dictionary object that will - # handle operators generically here, this is incomplete - return { - operators.add:{ - Date:Date, - Integer:Integer, - Numeric:Numeric, - }, - operators.mul:{ - Interval:Interval, - Integer:Integer, - Numeric:Numeric, - }, - # Py2K - operators.div:{ - Integer:Integer, - Numeric:Numeric, - }, - # end Py2K - operators.truediv:{ - Integer:Integer, - Numeric:Numeric, - }, - operators.sub:{ - Integer:Integer, - Numeric:Numeric, - }, - } - -class SmallInteger(Integer): - """A type for smaller ``int`` integers. - - Typically generates a ``SMALLINT`` in DDL, and otherwise acts like - a normal :class:`Integer` on the Python side. - - """ - - __visit_name__ = 'small_integer' - -class BigInteger(Integer): - """A type for bigger ``int`` integers. 
- - Typically generates a ``BIGINT`` in DDL, and otherwise acts like - a normal :class:`Integer` on the Python side. - - """ - - __visit_name__ = 'big_integer' - -class Numeric(_DateAffinity, TypeEngine): - """A type for fixed precision numbers. - - Typically generates DECIMAL or NUMERIC. Returns - ``decimal.Decimal`` objects by default, applying - conversion as needed. - - """ - - __visit_name__ = 'numeric' - - def __init__(self, precision=None, scale=None, asdecimal=True): - """ - Construct a Numeric. - - :param precision: the numeric precision for use in DDL ``CREATE TABLE``. - - :param scale: the numeric scale for use in DDL ``CREATE TABLE``. - - :param asdecimal: default True. Return whether or not - values should be sent as Python Decimal objects, or - as floats. Different DBAPIs send one or the other based on - datatypes - the Numeric type will ensure that return values - are one or the other across DBAPIs consistently. - - When using the ``Numeric`` type, care should be taken to ensure - that the asdecimal setting is apppropriate for the DBAPI in use - - when Numeric applies a conversion from Decimal->float or float-> - Decimal, this conversion incurs an additional performance overhead - for all result columns received. - - DBAPIs that return Decimal natively (e.g. psycopg2) will have - better accuracy and higher performance with a setting of ``True``, - as the native translation to Decimal reduces the amount of floating- - point issues at play, and the Numeric type itself doesn't need - to apply any further conversions. However, another DBAPI which - returns floats natively *will* incur an additional conversion - overhead, and is still subject to floating point data loss - in - which case ``asdecimal=False`` will at least remove the extra - conversion overhead. - - """ - self.precision = precision - self.scale = scale - self.asdecimal = asdecimal - - def adapt(self, impltype): - return impltype( - precision=self.precision, - scale=self.scale, - asdecimal=self.asdecimal) - - def get_dbapi_type(self, dbapi): - return dbapi.NUMBER - - def bind_processor(self, dialect): - if dialect.supports_native_decimal: - return None - else: - return processors.to_float - - def result_processor(self, dialect, coltype): - if self.asdecimal: - if dialect.supports_native_decimal: - # we're a "numeric", DBAPI will give us Decimal directly - return None - else: - # we're a "numeric", DBAPI returns floats, convert. - if self.scale is not None: - return processors.to_decimal_processor_factory(_python_Decimal, self.scale) - else: - return processors.to_decimal_processor_factory(_python_Decimal) - else: - if dialect.supports_native_decimal: - return processors.to_float - else: - return None - - @util.memoized_property - def _expression_adaptations(self): - return { - operators.mul:{ - Interval:Interval - }, - } - - -class Float(Numeric): - """A type for ``float`` numbers. - - Returns Python ``float`` objects by default, applying - conversion as needed. - - """ - - __visit_name__ = 'float' - - def __init__(self, precision=None, asdecimal=False, **kwargs): - """ - Construct a Float. - - :param precision: the numeric precision for use in DDL ``CREATE TABLE``. - - :param asdecimal: the same flag as that of :class:`Numeric`, but - defaults to ``False``. 
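
A hedged sketch of the asdecimal flag discussed above: Numeric yields decimal.Decimal by default while Float yields plain Python floats, and either default can be flipped:

    from sqlalchemy import Numeric, Float

    price = Numeric(10, 2)                    # results come back as Decimal
    ratio = Float(asdecimal=True)             # force Decimal from a Float
    approx = Numeric(10, 2, asdecimal=False)  # skip Decimal conversion
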
- - """ - self.precision = precision - self.asdecimal = asdecimal - - def adapt(self, impltype): - return impltype(precision=self.precision, asdecimal=self.asdecimal) - - def result_processor(self, dialect, coltype): - if self.asdecimal: - return processors.to_decimal_processor_factory(_python_Decimal) - else: - return None - - -class DateTime(_DateAffinity, TypeEngine): - """A type for ``datetime.datetime()`` objects. - - Date and time types return objects from the Python ``datetime`` - module. Most DBAPIs have built in support for the datetime - module, with the noted exception of SQLite. In the case of - SQLite, date and time types are stored as strings which are then - converted back to datetime objects when rows are returned. - - """ - - __visit_name__ = 'datetime' - - def __init__(self, timezone=False): - self.timezone = timezone - - def adapt(self, impltype): - return impltype(timezone=self.timezone) - - def get_dbapi_type(self, dbapi): - return dbapi.DATETIME - - @util.memoized_property - def _expression_adaptations(self): - return { - operators.add:{ - Interval:DateTime, - }, - operators.sub:{ - Interval:DateTime, - DateTime:Interval, - }, - } - - -class Date(_DateAffinity,TypeEngine): - """A type for ``datetime.date()`` objects.""" - - __visit_name__ = 'date' - - def get_dbapi_type(self, dbapi): - return dbapi.DATETIME - - @util.memoized_property - def _expression_adaptations(self): - return { - operators.add:{ - Integer:Date, - Interval:DateTime, - Time:DateTime, - }, - operators.sub:{ - # date - integer = date - Integer:Date, - - # date - date = integer. - Date:Integer, - - Interval:DateTime, - - # date - datetime = interval, - # this one is not in the PG docs - # but works - DateTime:Interval, - }, - } - - -class Time(_DateAffinity,TypeEngine): - """A type for ``datetime.time()`` objects.""" - - __visit_name__ = 'time' - - def __init__(self, timezone=False): - self.timezone = timezone - - def adapt(self, impltype): - return impltype(timezone=self.timezone) - - def get_dbapi_type(self, dbapi): - return dbapi.DATETIME - - @util.memoized_property - def _expression_adaptations(self): - return { - operators.add:{ - Date:DateTime, - Interval:Time - }, - operators.sub:{ - Time:Interval, - Interval:Time, - }, - } - - -class _Binary(TypeEngine): - """Define base behavior for binary types.""" - - def __init__(self, length=None): - self.length = length - - # Python 3 - sqlite3 doesn't need the `Binary` conversion - # here, though pg8000 does to indicate "bytea" - def bind_processor(self, dialect): - DBAPIBinary = dialect.dbapi.Binary - def process(value): - if value is not None: - return DBAPIBinary(value) - else: - return None - return process - - # Python 3 has native bytes() type - # both sqlite3 and pg8000 seem to return it - # (i.e. and not 'memoryview') - # Py2K - def result_processor(self, dialect, coltype): - if util.jython: - def process(value): - if value is not None: - if isinstance(value, array.array): - return value.tostring() - return str(value) - else: - return None - else: - process = processors.to_str - return process - # end Py2K - - def adapt(self, impltype): - return impltype(length=self.length) - - def get_dbapi_type(self, dbapi): - return dbapi.BINARY - -class LargeBinary(_Binary): - """A type for large binary byte data. - - The Binary type generates BLOB or BYTEA when tables are created, - and also converts incoming values using the ``Binary`` callable - provided by each DB-API. 
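
A hedged declaration sketch for LargeBinary as described above (names invented); the DBAPI's own Binary() wrapper is applied to bound values on the way in:

    from sqlalchemy import MetaData, Table, Column, Integer, LargeBinary

    m = MetaData()
    files = Table(
        'files', m,
        Column('id', Integer, primary_key=True),
        Column('payload', LargeBinary()))
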
- - """ - - __visit_name__ = 'large_binary' - - def __init__(self, length=None): - """ - Construct a LargeBinary type. - - :param length: optional, a length for the column for use in - DDL statements, for those BLOB types that accept a length - (i.e. MySQL). It does *not* produce a small BINARY/VARBINARY - type - use the BINARY/VARBINARY types specifically for those. - May be safely omitted if no ``CREATE - TABLE`` will be issued. Certain databases may require a - *length* for use in DDL, and will raise an exception when - the ``CREATE TABLE`` DDL is issued. - - """ - _Binary.__init__(self, length=length) - -class Binary(LargeBinary): - """Deprecated. Renamed to LargeBinary.""" - - def __init__(self, *arg, **kw): - util.warn_deprecated("The Binary type has been renamed to LargeBinary.") - LargeBinary.__init__(self, *arg, **kw) - -class SchemaType(object): - """Mark a type as possibly requiring schema-level DDL for usage. - - Supports types that must be explicitly created/dropped (i.e. PG ENUM type) - as well as types that are complimented by table or schema level - constraints, triggers, and other rules. - - """ - - def __init__(self, **kw): - self.name = kw.pop('name', None) - self.quote = kw.pop('quote', None) - self.schema = kw.pop('schema', None) - self.metadata = kw.pop('metadata', None) - if self.metadata: - self.metadata.append_ddl_listener( - 'before-create', - util.portable_instancemethod(self._on_metadata_create) - ) - self.metadata.append_ddl_listener( - 'after-drop', - util.portable_instancemethod(self._on_metadata_drop) - ) - - def _set_parent(self, column): - column._on_table_attach(util.portable_instancemethod(self._set_table)) - - def _set_table(self, table, column): - table.append_ddl_listener( - 'before-create', - util.portable_instancemethod(self._on_table_create) - ) - table.append_ddl_listener( - 'after-drop', - util.portable_instancemethod(self._on_table_drop) - ) - if self.metadata is None: - table.metadata.append_ddl_listener( - 'before-create', - util.portable_instancemethod(self._on_metadata_create) - ) - table.metadata.append_ddl_listener( - 'after-drop', - util.portable_instancemethod(self._on_metadata_drop) - ) - - @property - def bind(self): - return self.metadata and self.metadata.bind or None - - def create(self, bind=None, checkfirst=False): - """Issue CREATE ddl for this type, if applicable.""" - - from sqlalchemy.schema import _bind_or_error - if bind is None: - bind = _bind_or_error(self) - t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): - t.create(bind=bind, checkfirst=checkfirst) - - def drop(self, bind=None, checkfirst=False): - """Issue DROP ddl for this type, if applicable.""" - - from sqlalchemy.schema import _bind_or_error - if bind is None: - bind = _bind_or_error(self) - t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): - t.drop(bind=bind, checkfirst=checkfirst) - - def _on_table_create(self, event, target, bind, **kw): - t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): - t._on_table_create(event, target, bind, **kw) - - def _on_table_drop(self, event, target, bind, **kw): - t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): - t._on_table_drop(event, target, bind, **kw) - - def _on_metadata_create(self, event, target, bind, **kw): - t = self.dialect_impl(bind.dialect) - if t is not self and isinstance(t, SchemaType): - t._on_metadata_create(event, target, bind, **kw) - - def _on_metadata_drop(self, 
-class Enum(String, SchemaType):
-    """Generic Enum Type.
-
-    The Enum type provides a set of possible string values which the
-    column is constrained towards.
-
-    By default, uses the backend's native ENUM type if available,
-    else uses VARCHAR + a CHECK constraint.
-    """
-
-    __visit_name__ = 'enum'
-
-    def __init__(self, *enums, **kw):
-        """Construct an enum.
-
-        Keyword arguments which don't apply to a specific backend are ignored
-        by that backend.
-
-        :param \*enums: string or unicode enumeration labels. If unicode labels
-           are present, the `convert_unicode` flag is auto-enabled.
-
-        :param convert_unicode: Enable unicode-aware bind parameter and result-set
-           processing for this Enum's data. This is set automatically based on
-           the presence of unicode label strings.
-
-        :param metadata: Associate this type directly with a ``MetaData`` object.
-           For types that exist on the target database as an independent schema
-           construct (Postgresql), this type will be created and dropped within
-           ``create_all()`` and ``drop_all()`` operations. If the type is not
-           associated with any ``MetaData`` object, it will associate itself with
-           each ``Table`` in which it is used, and will be created when any of
-           those individual tables are created, after a check is performed for
-           its existence. The type is only dropped when ``drop_all()`` is called
-           for that ``Table`` object's metadata, however.
-
-        :param name: The name of this type. This is required for Postgresql and
-           any future supported database which requires an explicitly named type,
-           or an explicitly named constraint in order to generate the type and/or
-           a table that uses it.
-
-        :param native_enum: Use the database's native ENUM type when available.
-           Defaults to True. When False, uses VARCHAR + check constraint
-           for all backends.
-
-        :param schema: Schema name of this type. For types that exist on the target
-           database as an independent schema construct (Postgresql), this
-           parameter specifies the named schema in which the type is present.
-
-        :param quote: Force quoting to be on or off on the type's name. If left as
-           the default of `None`, the usual schema-level "case
-           sensitive"/"reserved name" rules are used to determine if this type's
-           name should be quoted.
-
-        """
-        self.enums = enums
-        self.native_enum = kw.pop('native_enum', True)
-        convert_unicode = kw.pop('convert_unicode', None)
-        if convert_unicode is None:
-            for e in enums:
-                if isinstance(e, unicode):
-                    convert_unicode = True
-                    break
-            else:
-                convert_unicode = False
-
-        if self.enums:
-            length = max(len(x) for x in self.enums)
-        else:
-            length = 0
-        String.__init__(self,
-                        length=length,
-                        convert_unicode=convert_unicode,
-                        )
-        SchemaType.__init__(self, **kw)
-
-    def _should_create_constraint(self, compiler):
-        return not self.native_enum or \
-            not compiler.dialect.supports_native_enum
-
-    def _set_table(self, table, column):
-        if self.native_enum:
-            SchemaType._set_table(self, table, column)
-
-        e = schema.CheckConstraint(
-            column.in_(self.enums),
-            name=self.name,
-            _create_rule=util.portable_instancemethod(self._should_create_constraint)
-        )
-        table.append_constraint(e)
-
-    def adapt(self, impltype):
-        if issubclass(impltype, Enum):
-            return impltype(name=self.name,
-                            quote=self.quote,
-                            schema=self.schema,
-                            metadata=self.metadata,
-                            convert_unicode=self.convert_unicode,
-                            *self.enums
-                            )
-        else:
-            return super(Enum, self).adapt(impltype)
-
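Typical Enum usage, per the parameters documented above (table and type names are illustrative):

from sqlalchemy import Column, Enum, Integer, MetaData, Table

metadata = MetaData()
posts = Table('posts', metadata,
              Column('id', Integer, primary_key=True),
              # name= is required for backends that create a named type
              Column('status', Enum('draft', 'published', name='post_status')))

# On Postgresql, metadata.create_all() emits CREATE TYPE post_status AS ENUM;
# on backends without a native enum it falls back to VARCHAR + CHECK.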
- - """ - self.enums = enums - self.native_enum = kw.pop('native_enum', True) - convert_unicode= kw.pop('convert_unicode', None) - if convert_unicode is None: - for e in enums: - if isinstance(e, unicode): - convert_unicode = True - break - else: - convert_unicode = False - - if self.enums: - length =max(len(x) for x in self.enums) - else: - length = 0 - String.__init__(self, - length =length, - convert_unicode=convert_unicode, - ) - SchemaType.__init__(self, **kw) - - def _should_create_constraint(self, compiler): - return not self.native_enum or \ - not compiler.dialect.supports_native_enum - - def _set_table(self, table, column): - if self.native_enum: - SchemaType._set_table(self, table, column) - - - e = schema.CheckConstraint( - column.in_(self.enums), - name=self.name, - _create_rule=util.portable_instancemethod(self._should_create_constraint) - ) - table.append_constraint(e) - - def adapt(self, impltype): - if issubclass(impltype, Enum): - return impltype(name=self.name, - quote=self.quote, - schema=self.schema, - metadata=self.metadata, - convert_unicode=self.convert_unicode, - *self.enums - ) - else: - return super(Enum, self).adapt(impltype) - -class PickleType(MutableType, TypeDecorator): - """Holds Python objects. - - PickleType builds upon the Binary type to apply Python's - ``pickle.dumps()`` to incoming objects, and ``pickle.loads()`` on - the way out, allowing any pickleable Python object to be stored as - a serialized binary field. - - """ - - impl = LargeBinary - - def __init__(self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None, mutable=True, comparator=None): - """ - Construct a PickleType. - - :param protocol: defaults to ``pickle.HIGHEST_PROTOCOL``. - - :param pickler: defaults to cPickle.pickle or pickle.pickle if - cPickle is not available. May be any object with - pickle-compatible ``dumps` and ``loads`` methods. - - :param mutable: defaults to True; implements - :meth:`AbstractType.is_mutable`. When ``True``, incoming - objects should provide an ``__eq__()`` method which - performs the desired deep comparison of members, or the - ``comparator`` argument must be present. - - :param comparator: optional. a 2-arg callable predicate used - to compare values of this type. Otherwise, - the == operator is used to compare values. 
- - """ - self.protocol = protocol - self.pickler = pickler or pickle - self.mutable = mutable - self.comparator = comparator - super(PickleType, self).__init__() - - def bind_processor(self, dialect): - impl_processor = self.impl.bind_processor(dialect) - dumps = self.pickler.dumps - protocol = self.protocol - if impl_processor: - def process(value): - if value is not None: - value = dumps(value, protocol) - return impl_processor(value) - else: - def process(value): - if value is not None: - value = dumps(value, protocol) - return value - return process - - def result_processor(self, dialect, coltype): - impl_processor = self.impl.result_processor(dialect, coltype) - loads = self.pickler.loads - if impl_processor: - def process(value): - value = impl_processor(value) - if value is None: - return None - return loads(value) - else: - def process(value): - if value is None: - return None - return loads(value) - return process - - def copy_value(self, value): - if self.mutable: - return self.pickler.loads(self.pickler.dumps(value, self.protocol)) - else: - return value - - def compare_values(self, x, y): - if self.comparator: - return self.comparator(x, y) - else: - return x == y - - def is_mutable(self): - return self.mutable - - -class Boolean(TypeEngine, SchemaType): - """A bool datatype. - - Boolean typically uses BOOLEAN or SMALLINT on the DDL side, and on - the Python side deals in ``True`` or ``False``. - - """ - - __visit_name__ = 'boolean' - - def __init__(self, create_constraint=True, name=None): - """Construct a Boolean. - - :param create_constraint: defaults to True. If the boolean - is generated as an int/smallint, also create a CHECK constraint - on the table that ensures 1 or 0 as a value. - - :param name: if a CHECK constraint is generated, specify - the name of the constraint. - - """ - self.create_constraint = create_constraint - self.name = name - - def _should_create_constraint(self, compiler): - return not compiler.dialect.supports_native_boolean - - def _set_table(self, table, column): - if not self.create_constraint: - return - - e = schema.CheckConstraint( - column.in_([0, 1]), - name=self.name, - _create_rule=util.portable_instancemethod(self._should_create_constraint) - ) - table.append_constraint(e) - - def result_processor(self, dialect, coltype): - if dialect.supports_native_boolean: - return None - else: - return processors.int_to_boolean - -class Interval(_DateAffinity, TypeDecorator): - """A type for ``datetime.timedelta()`` objects. - - The Interval type deals with ``datetime.timedelta`` objects. In - PostgreSQL, the native ``INTERVAL`` type is used; for others, the - value is stored as a date which is relative to the "epoch" - (Jan. 1, 1970). - - Note that the ``Interval`` type does not currently provide - date arithmetic operations on platforms which do not support - interval types natively. Such operations usually require - transformation of both sides of the expression (such as, conversion - of both sides into integer epoch values first) which currently - is a manual procedure (such as via :attr:`~sqlalchemy.sql.expression.func`). - - """ - - impl = DateTime - epoch = dt.datetime.utcfromtimestamp(0) - - def __init__(self, native=True, - second_precision=None, - day_precision=None): - """Construct an Interval object. - - :param native: when True, use the actual - INTERVAL type provided by the database, if - supported (currently Postgresql, Oracle). - Otherwise, represent the interval data as - an epoch value regardless. 
-class FLOAT(Float):
-    """The SQL FLOAT type."""
-
-    __visit_name__ = 'FLOAT'
-
-class NUMERIC(Numeric):
-    """The SQL NUMERIC type."""
-
-    __visit_name__ = 'NUMERIC'
-
-
-class DECIMAL(Numeric):
-    """The SQL DECIMAL type."""
-
-    __visit_name__ = 'DECIMAL'
-
-
-class INTEGER(Integer):
-    """The SQL INT or INTEGER type."""
-
-    __visit_name__ = 'INTEGER'
-INT = INTEGER
-
-
-class SMALLINT(SmallInteger):
-    """The SQL SMALLINT type."""
-
-    __visit_name__ = 'SMALLINT'
-
-
-class BIGINT(BigInteger):
-    """The SQL BIGINT type."""
-
-    __visit_name__ = 'BIGINT'
-
-class TIMESTAMP(DateTime):
-    """The SQL TIMESTAMP type."""
-
-    __visit_name__ = 'TIMESTAMP'
-
-    def get_dbapi_type(self, dbapi):
-        return dbapi.TIMESTAMP
-
-class DATETIME(DateTime):
-    """The SQL DATETIME type."""
-
-    __visit_name__ = 'DATETIME'
-
-
-class DATE(Date):
-    """The SQL DATE type."""
-
-    __visit_name__ = 'DATE'
-
-
-class TIME(Time):
-    """The SQL TIME type."""
-
-    __visit_name__ = 'TIME'
-
-class TEXT(Text):
-    """The SQL TEXT type."""
-
-    __visit_name__ = 'TEXT'
-
-class CLOB(Text):
-    """The CLOB type.
-
-    This type is found in Oracle and Informix.
-    """
-
-    __visit_name__ = 'CLOB'
-
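These uppercase types render verbatim in DDL, unlike the generic types, which each dialect resolves for itself. A quick sketch using the default compiler (table and column names are illustrative):

from sqlalchemy import Column, MetaData, NUMERIC, Table, TIMESTAMP
from sqlalchemy.schema import CreateTable

metadata = MetaData()
products = Table('products', metadata,
                 Column('price', NUMERIC(10, 2)),
                 Column('created', TIMESTAMP))

# str() compiles against the default dialect; NUMERIC and TIMESTAMP
# pass through to the DDL exactly as written
print(CreateTable(products))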
- """ - - __visit_name__ = 'CLOB' - -class VARCHAR(String): - """The SQL VARCHAR type.""" - - __visit_name__ = 'VARCHAR' - -class NVARCHAR(Unicode): - """The SQL NVARCHAR type.""" - - __visit_name__ = 'NVARCHAR' - -class CHAR(String): - """The SQL CHAR type.""" - - __visit_name__ = 'CHAR' - - -class NCHAR(Unicode): - """The SQL NCHAR type.""" - - __visit_name__ = 'NCHAR' - - -class BLOB(LargeBinary): - """The SQL BLOB type.""" - - __visit_name__ = 'BLOB' - -class BINARY(_Binary): - """The SQL BINARY type.""" - - __visit_name__ = 'BINARY' - -class VARBINARY(_Binary): - """The SQL VARBINARY type.""" - - __visit_name__ = 'VARBINARY' - - -class BOOLEAN(Boolean): - """The SQL BOOLEAN type.""" - - __visit_name__ = 'BOOLEAN' - -NULLTYPE = NullType() -BOOLEANTYPE = Boolean() - -# using VARCHAR/NCHAR so that we dont get the genericized "String" -# type which usually resolves to TEXT/CLOB -type_map = { - str: String(), - # Py3K - #bytes : LargeBinary(), - # Py2K - unicode : Unicode(), - # end Py2K - int : Integer(), - float : Numeric(), - bool: BOOLEANTYPE, - _python_Decimal : Numeric(), - dt.date : Date(), - dt.datetime : DateTime(), - dt.time : Time(), - dt.timedelta : Interval(), - NoneType: NULLTYPE -} +__all__ = ['TypeEngine', 'TypeDecorator', 'UserDefinedType', + 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR', 'TEXT', 'Text', + 'FLOAT', 'NUMERIC', 'REAL', 'DECIMAL', 'TIMESTAMP', 'DATETIME', + 'CLOB', 'BLOB', 'BINARY', 'VARBINARY', 'BOOLEAN', 'BIGINT', + 'SMALLINT', 'INTEGER', 'DATE', 'TIME', 'String', 'Integer', + 'SmallInteger', 'BigInteger', 'Numeric', 'Float', 'DateTime', + 'Date', 'Time', 'LargeBinary', 'Binary', 'Boolean', 'Unicode', + 'Concatenable', 'UnicodeText', 'PickleType', 'Interval', 'Enum', + 'Indexable', 'ARRAY', 'JSON'] + +from .sql.type_api import ( + adapt_type, + TypeEngine, + TypeDecorator, + Variant, + to_instance, + UserDefinedType +) +from .sql.sqltypes import ( + ARRAY, + BIGINT, + BINARY, + BLOB, + BOOLEAN, + BigInteger, + Binary, + _Binary, + Boolean, + CHAR, + CLOB, + Concatenable, + DATE, + DATETIME, + DECIMAL, + Date, + DateTime, + Enum, + FLOAT, + Float, + Indexable, + INT, + INTEGER, + Integer, + Interval, + JSON, + LargeBinary, + MatchType, + NCHAR, + NVARCHAR, + NullType, + NULLTYPE, + NUMERIC, + Numeric, + PickleType, + REAL, + SchemaType, + SMALLINT, + SmallInteger, + String, + STRINGTYPE, + TEXT, + TIME, + TIMESTAMP, + Text, + Time, + Unicode, + UnicodeText, + VARBINARY, + VARCHAR, + )